Skip to content

Commit e258afe

Browse files
Fix checkC1Access to use caller's client, add AWS CLI check
checkC1Access was building a fake cobra.Command with hardcoded flags just to call cmdContext() for a client. Since awsCredentialsRun already has a client from cmdContext, pass it through instead. Also add requireAWSCLI() check before shelling out to aws, so users get a clear error message instead of an exec failure.
1 parent a26cff4 commit e258afe

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

cmd/cone/aws.go

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ Can be used directly or as an AWS credential_process:
297297
}
298298

299299
func awsCredentialsRun(cmd *cobra.Command, args []string) error {
300-
ctx, _, _, err := cmdContext(cmd)
300+
ctx, c, _, err := cmdContext(cmd)
301301
if err != nil {
302302
return err
303303
}
@@ -308,7 +308,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error {
308308

309309
profileName := args[0]
310310

311-
accessResult, err := checkC1Access(ctx, profileName)
311+
accessResult, err := checkC1Access(ctx, c, profileName)
312312
if err != nil {
313313
return fmt.Errorf("failed to check access: %w", err)
314314
}
@@ -319,7 +319,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error {
319319
}
320320

321321
// Fetch the entitlement to get its max grant duration.
322-
entitlement, err := accessResult.client.GetEntitlement(ctx, accessResult.appID, accessResult.entitlementID)
322+
entitlement, err := c.GetEntitlement(ctx, accessResult.appID, accessResult.entitlementID)
323323
if err != nil {
324324
return fmt.Errorf("failed to get entitlement details: %w", err)
325325
}
@@ -332,7 +332,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error {
332332

333333
fmt.Fprintf(os.Stderr, "No active grant for %q — submitting access request...\n", profileName)
334334

335-
grantResp, err := accessResult.client.CreateGrantTask(
335+
grantResp, err := c.CreateGrantTask(
336336
ctx,
337337
accessResult.appID,
338338
accessResult.entitlementID,
@@ -357,7 +357,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error {
357357
time.Sleep(autoRequestPollInterval)
358358
fmt.Fprintf(os.Stderr, ".")
359359

360-
taskResp, err := accessResult.client.GetTask(ctx, taskID)
360+
taskResp, err := c.GetTask(ctx, taskID)
361361
if err != nil {
362362
break
363363
}
@@ -491,6 +491,13 @@ func getSSOToken(ssoStartURL string) (string, error) {
491491
return "", fmt.Errorf("no valid SSO token found for %s", ssoStartURL)
492492
}
493493

494+
func requireAWSCLI() error {
495+
if _, err := exec.LookPath("aws"); err != nil {
496+
return fmt.Errorf("the AWS CLI is required but was not found on PATH — install it from https://aws.amazon.com/cli/")
497+
}
498+
return nil
499+
}
500+
494501
func ssoLogin() error {
495502
fmt.Fprintf(os.Stderr, "AWS SSO session expired. Logging in...\n")
496503
loginCmd := exec.Command("aws", "sso", "login", "--sso-session", "cone-sso")
@@ -519,6 +526,10 @@ func getRoleCredentials(token, accountID, roleName, ssoRegion string) ([]byte, e
519526
}
520527

521528
func getTemporaryCredentials(accountID, roleName, ssoStartURL, ssoRegion string) (*AWSCredentials, error) {
529+
if err := requireAWSCLI(); err != nil {
530+
return nil, err
531+
}
532+
522533
token, err := getSSOToken(ssoStartURL)
523534
if err != nil {
524535
if loginErr := ssoLogin(); loginErr != nil {
@@ -576,27 +587,9 @@ type accessCheckResult struct {
576587
appID string
577588
entitlementID string
578589
userID string
579-
client client.C1Client
580590
}
581591

582-
func checkC1Access(ctx context.Context, profileName string) (*accessCheckResult, error) {
583-
// Build a minimal command to get a client context.
584-
tempCmd := &cobra.Command{Use: "temp"}
585-
tempCmd.PersistentFlags().StringP("profile", "p", "default", "")
586-
tempCmd.PersistentFlags().BoolP("non-interactive", "i", false, "")
587-
tempCmd.PersistentFlags().String("client-id", "", "")
588-
tempCmd.PersistentFlags().String("client-secret", "", "")
589-
tempCmd.PersistentFlags().String("api-endpoint", "", "")
590-
tempCmd.PersistentFlags().StringP("output", "o", "table", "")
591-
tempCmd.PersistentFlags().Bool("debug", false, "")
592-
tempCmd.PersistentFlags().String("log-level", "", "")
593-
tempCmd.SetContext(ctx)
594-
595-
_, c, _, err := cmdContext(tempCmd)
596-
if err != nil {
597-
return nil, err
598-
}
599-
592+
func checkC1Access(ctx context.Context, c client.C1Client, profileName string) (*accessCheckResult, error) {
600593
userInfo, err := c.AuthIntrospect(ctx)
601594
if err != nil {
602595
return nil, err
@@ -613,7 +606,6 @@ func checkC1Access(ctx context.Context, profileName string) (*accessCheckResult,
613606

614607
result := &accessCheckResult{
615608
userID: userID,
616-
client: c,
617609
}
618610

619611
for _, ent := range entitlements {

0 commit comments

Comments
 (0)