diff --git a/cmd/ecr-credential-provider/main.go b/cmd/ecr-credential-provider/main.go index aa21bd7d72..60f7ded6cf 100644 --- a/cmd/ecr-credential-provider/main.go +++ b/cmd/ecr-credential-provider/main.go @@ -179,15 +179,18 @@ func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, i }, nil } -func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) (aws.CredentialsProvider, error) { - var err error +func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) aws.CredentialsProvider { + if request.ServiceAccountToken == "" { + return nil + } arn, ok := request.ServiceAccountAnnotations["eks.amazonaws.com/ecr-role-arn"] if !ok { arn = os.Getenv("AWS_ECR_ROLE_ARN") } if arn == "" { - return nil, errors.New("no arn provided, cannot assume role using ServiceAccountToken") + klog.Info("no arn provided, cannot assume role using ServiceAccountToken") + return nil } if e.sts == nil { @@ -195,10 +198,12 @@ func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.Cr if imageHost != ecrPublicHost { region = parseRegionFromECRPrivateHost(imageHost) } - e.sts, err = stsProvider(ctx, region) - } - if err != nil { - return nil, err + sts, err := stsProvider(ctx, region) + if err != nil { + klog.Errorf("failed to create sts client, cannot assume role: %v", err) + return nil + } + e.sts = sts } return aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) { @@ -215,8 +220,7 @@ func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.Cr SecretAccessKey: *assumeOutput.Credentials.SecretAccessKey, SessionToken: *assumeOutput.Credentials.SessionToken, }, nil - }), - nil + }) } func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialProviderRequest, args []string) (*v1.CredentialProviderResponse, error) { @@ -232,14 +236,7 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialPr return nil, err } - var credentialsProvider aws.CredentialsProvider = nil - if request.ServiceAccountToken != "" { - credentialsProvider, err = e.buildCredentialsProvider(ctx, request, imageHost) - if err != nil { - return nil, err - } - } - + credentialsProvider := e.buildCredentialsProvider(ctx, request, imageHost) if imageHost == ecrPublicHost { var optFns = []func(*ecrpublic.Options){} if credentialsProvider != nil { diff --git a/cmd/ecr-credential-provider/main_test.go b/cmd/ecr-credential-provider/main_test.go index 1e4f6599cc..6c0e16fa39 100644 --- a/cmd/ecr-credential-provider/main_test.go +++ b/cmd/ecr-credential-provider/main_test.go @@ -48,7 +48,10 @@ func (m *MockedECR) GetAuthorizationToken(ctx context.Context, params *ecr.GetAu fn(&opts) } if opts.Credentials != nil { - opts.Credentials.Retrieve(ctx) + _, err := opts.Credentials.Retrieve(ctx) + if err != nil { + return nil, err + } } if args.Get(1) != nil { @@ -64,6 +67,18 @@ type MockedECRPublic struct { func (m *MockedECRPublic) GetAuthorizationToken(ctx context.Context, params *ecrpublic.GetAuthorizationTokenInput, optFns ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error) { args := m.Called(ctx, params) + + opts := ecrpublic.Options{} + for _, fn := range optFns { + fn(&opts) + } + if opts.Credentials != nil { + _, err := opts.Credentials.Retrieve(ctx) + if err != nil { + return nil, err + } + } + if args.Get(1) != nil { return args.Get(0).(*ecrpublic.GetAuthorizationTokenOutput), args.Get(1).(error) } @@ -245,7 +260,6 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) { }, }, response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"), - expectedError: errors.New("no arn provided, cannot assume role using ServiceAccountToken"), }, { name: "assume error", @@ -254,7 +268,7 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) { getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil), assumeRoleWithWebIdentityError: errors.New("injected error"), response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"), - expectedError: errors.New("injected error"), + expectedError: errors.New("failed to assume role: injected error"), }, } for _, testcase := range testcases { @@ -271,15 +285,15 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) { } mockSTS.On("AssumeRoleWithWebIdentity", mock.Anything, &expectedInput).Return(testcase.assumeRoleWithWebIdentityOutput, testcase.assumeRoleWithWebIdentityError) creds, err := p.GetCredentials(context.TODO(), testcase.request, testcase.args) - if err != nil { - if testcase.expectedError == nil { - t.Fatalf("got unexpected error %s", err.Error()) - - } + if err == nil && testcase.expectedError != nil { + t.Fatalf("expected error '%s' but got no error", testcase.expectedError) + } + if err != nil && testcase.expectedError == nil { + t.Fatalf("got unexpected error %s", err.Error()) + } - if testcase.expectedError.Error() != err.Error() { - t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) - } + if err != nil && testcase.expectedError.Error() != err.Error() { + t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) } if testcase.expectedError == nil { @@ -382,7 +396,6 @@ func Test_GetCredentials_Public(t *testing.T) { mockECRPublic.On("GetAuthorizationToken", mock.Anything, mock.Anything).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError) creds, err := p.GetCredentials(context.TODO(), &v1.CredentialProviderRequest{Image: testcase.image}, testcase.args) - if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) { t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) }