From 049f93d5143d59bfa59658db2b350c59e72d30ed Mon Sep 17 00:00:00 2001 From: Fletcher Woodruff Date: Fri, 17 Oct 2025 13:51:01 -0700 Subject: [PATCH] fix: allow providing unused ServiceAccountToken When ecr-credential-provider is configured to use ServiceAccountTokens for fetching ECR credentials, it should allow ignoring those credentials if the service account doesn't have the corresponding annotations. Update the provider to try to use a ServiceAccountToken and fall back to standard local credentials otherwise. --- cmd/ecr-credential-provider/main.go | 31 +++++++++----------- cmd/ecr-credential-provider/main_test.go | 37 ++++++++++++++++-------- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/cmd/ecr-credential-provider/main.go b/cmd/ecr-credential-provider/main.go index aa21bd7d72..60f7ded6cf 100644 --- a/cmd/ecr-credential-provider/main.go +++ b/cmd/ecr-credential-provider/main.go @@ -179,15 +179,18 @@ func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, i }, nil } -func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) (aws.CredentialsProvider, error) { - var err error +func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) aws.CredentialsProvider { + if request.ServiceAccountToken == "" { + return nil + } arn, ok := request.ServiceAccountAnnotations["eks.amazonaws.com/ecr-role-arn"] if !ok { arn = os.Getenv("AWS_ECR_ROLE_ARN") } if arn == "" { - return nil, errors.New("no arn provided, cannot assume role using ServiceAccountToken") + klog.Info("no arn provided, cannot assume role using ServiceAccountToken") + return nil } if e.sts == nil { @@ -195,10 +198,12 @@ func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.Cr if imageHost != ecrPublicHost { region = parseRegionFromECRPrivateHost(imageHost) } - e.sts, err = stsProvider(ctx, region) - } - if err != nil { - return nil, err + sts, err := stsProvider(ctx, region) + if err != nil { + klog.Errorf("failed to create sts client, cannot assume role: %v", err) + return nil + } + e.sts = sts } return aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) { @@ -215,8 +220,7 @@ func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.Cr SecretAccessKey: *assumeOutput.Credentials.SecretAccessKey, SessionToken: *assumeOutput.Credentials.SessionToken, }, nil - }), - nil + }) } func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialProviderRequest, args []string) (*v1.CredentialProviderResponse, error) { @@ -232,14 +236,7 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialPr return nil, err } - var credentialsProvider aws.CredentialsProvider = nil - if request.ServiceAccountToken != "" { - credentialsProvider, err = e.buildCredentialsProvider(ctx, request, imageHost) - if err != nil { - return nil, err - } - } - + credentialsProvider := e.buildCredentialsProvider(ctx, request, imageHost) if imageHost == ecrPublicHost { var optFns = []func(*ecrpublic.Options){} if credentialsProvider != nil { diff --git a/cmd/ecr-credential-provider/main_test.go b/cmd/ecr-credential-provider/main_test.go index 1e4f6599cc..6c0e16fa39 100644 --- a/cmd/ecr-credential-provider/main_test.go +++ b/cmd/ecr-credential-provider/main_test.go @@ -48,7 +48,10 @@ func (m *MockedECR) GetAuthorizationToken(ctx context.Context, params *ecr.GetAu fn(&opts) } if opts.Credentials != nil { - opts.Credentials.Retrieve(ctx) + _, err := opts.Credentials.Retrieve(ctx) + if err != nil { + return nil, err + } } if args.Get(1) != nil { @@ -64,6 +67,18 @@ type MockedECRPublic struct { func (m *MockedECRPublic) GetAuthorizationToken(ctx context.Context, params *ecrpublic.GetAuthorizationTokenInput, optFns ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error) { args := m.Called(ctx, params) + + opts := ecrpublic.Options{} + for _, fn := range optFns { + fn(&opts) + } + if opts.Credentials != nil { + _, err := opts.Credentials.Retrieve(ctx) + if err != nil { + return nil, err + } + } + if args.Get(1) != nil { return args.Get(0).(*ecrpublic.GetAuthorizationTokenOutput), args.Get(1).(error) } @@ -245,7 +260,6 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) { }, }, response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"), - expectedError: errors.New("no arn provided, cannot assume role using ServiceAccountToken"), }, { name: "assume error", @@ -254,7 +268,7 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) { getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil), assumeRoleWithWebIdentityError: errors.New("injected error"), response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"), - expectedError: errors.New("injected error"), + expectedError: errors.New("failed to assume role: injected error"), }, } for _, testcase := range testcases { @@ -271,15 +285,15 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) { } mockSTS.On("AssumeRoleWithWebIdentity", mock.Anything, &expectedInput).Return(testcase.assumeRoleWithWebIdentityOutput, testcase.assumeRoleWithWebIdentityError) creds, err := p.GetCredentials(context.TODO(), testcase.request, testcase.args) - if err != nil { - if testcase.expectedError == nil { - t.Fatalf("got unexpected error %s", err.Error()) - - } + if err == nil && testcase.expectedError != nil { + t.Fatalf("expected error '%s' but got no error", testcase.expectedError) + } + if err != nil && testcase.expectedError == nil { + t.Fatalf("got unexpected error %s", err.Error()) + } - if testcase.expectedError.Error() != err.Error() { - t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) - } + if err != nil && testcase.expectedError.Error() != err.Error() { + t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) } if testcase.expectedError == nil { @@ -382,7 +396,6 @@ func Test_GetCredentials_Public(t *testing.T) { mockECRPublic.On("GetAuthorizationToken", mock.Anything, mock.Anything).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError) creds, err := p.GetCredentials(context.TODO(), &v1.CredentialProviderRequest{Image: testcase.image}, testcase.args) - if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) { t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error()) }