Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions cmd/ecr-credential-provider/main.go
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far as I know, you can only assign 1 IAM role to a pod through IRSA. However, it's unclear whether SA token authentication requires IRSA. If the ecr-role-arn and role-arn have to be the same, then we should update this PR to only use role-arn as that is key that IRSA uses.

Copy link
Contributor Author

@fletcherw fletcherw Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little new to this but this is my understanding

Each pod has exactly one service account associated with it. By default, if you annotate that service account with eks.amazonaws.com/role-arn, the pod identity agent webhook will get a service account token for that node account, assume that role and pass those credentials into the pod's containers.

Separately, before a pod is created, if there is a CredentialProvider that has matchImages matching the image URL and that configures tokenAttributes, the kubelet will also create a Service Account token and pass that to the credential provider.

I don't see why the IAM role that the credential provider assumes would have to match the role that the pod identity webhook assumes, though I'm sure that for many people they will be the same.

Original file line number Diff line number Diff line change
Expand Up @@ -179,26 +179,31 @@ func (e *ecrPlugin) getPrivateCredsData(ctx context.Context, imageHost string, i
}, nil
}

func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) (aws.CredentialsProvider, error) {
var err error
func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.CredentialProviderRequest, imageHost string) aws.CredentialsProvider {
if request.ServiceAccountToken == "" {
return nil
}

arn, ok := request.ServiceAccountAnnotations["eks.amazonaws.com/ecr-role-arn"]
if !ok {
arn = os.Getenv("AWS_ECR_ROLE_ARN")
}
if arn == "" {
return nil, errors.New("no arn provided, cannot assume role using ServiceAccountToken")
klog.Info("no arn provided, cannot assume role using ServiceAccountToken")
return nil
}

if e.sts == nil {
region := ""
if imageHost != ecrPublicHost {
region = parseRegionFromECRPrivateHost(imageHost)
}
e.sts, err = stsProvider(ctx, region)
}
if err != nil {
return nil, err
sts, err := stsProvider(ctx, region)
if err != nil {
klog.Errorf("failed to create sts client, cannot assume role: %v", err)
return nil
}
e.sts = sts
}

return aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) {
Expand All @@ -215,8 +220,7 @@ func (e *ecrPlugin) buildCredentialsProvider(ctx context.Context, request *v1.Cr
SecretAccessKey: *assumeOutput.Credentials.SecretAccessKey,
SessionToken: *assumeOutput.Credentials.SessionToken,
}, nil
}),
nil
})
}

func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialProviderRequest, args []string) (*v1.CredentialProviderResponse, error) {
Expand All @@ -232,14 +236,7 @@ func (e *ecrPlugin) GetCredentials(ctx context.Context, request *v1.CredentialPr
return nil, err
}

var credentialsProvider aws.CredentialsProvider = nil
if request.ServiceAccountToken != "" {
credentialsProvider, err = e.buildCredentialsProvider(ctx, request, imageHost)
if err != nil {
return nil, err
}
}

credentialsProvider := e.buildCredentialsProvider(ctx, request, imageHost)
if imageHost == ecrPublicHost {
var optFns = []func(*ecrpublic.Options){}
if credentialsProvider != nil {
Expand Down
37 changes: 25 additions & 12 deletions cmd/ecr-credential-provider/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func (m *MockedECR) GetAuthorizationToken(ctx context.Context, params *ecr.GetAu
fn(&opts)
}
if opts.Credentials != nil {
opts.Credentials.Retrieve(ctx)
_, err := opts.Credentials.Retrieve(ctx)
if err != nil {
return nil, err
}
}

if args.Get(1) != nil {
Expand All @@ -64,6 +67,18 @@ type MockedECRPublic struct {

func (m *MockedECRPublic) GetAuthorizationToken(ctx context.Context, params *ecrpublic.GetAuthorizationTokenInput, optFns ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error) {
args := m.Called(ctx, params)

opts := ecrpublic.Options{}
for _, fn := range optFns {
fn(&opts)
}
if opts.Credentials != nil {
_, err := opts.Credentials.Retrieve(ctx)
if err != nil {
return nil, err
}
}

if args.Get(1) != nil {
return args.Get(0).(*ecrpublic.GetAuthorizationTokenOutput), args.Get(1).(error)
}
Expand Down Expand Up @@ -245,7 +260,6 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) {
},
},
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
expectedError: errors.New("no arn provided, cannot assume role using ServiceAccountToken"),
},
{
name: "assume error",
Expand All @@ -254,7 +268,7 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) {
getAuthorizationTokenOutput: generatePrivateGetAuthorizationTokenOutput("user", "pass", "", nil),
assumeRoleWithWebIdentityError: errors.New("injected error"),
response: generateResponse("123456789123.dkr.ecr.us-west-2.amazonaws.com", "user", "pass"),
expectedError: errors.New("injected error"),
expectedError: errors.New("failed to assume role: injected error"),
},
}
for _, testcase := range testcases {
Expand All @@ -271,15 +285,15 @@ func Test_GetCredentials_PrivateForServiceAccount(t *testing.T) {
}
mockSTS.On("AssumeRoleWithWebIdentity", mock.Anything, &expectedInput).Return(testcase.assumeRoleWithWebIdentityOutput, testcase.assumeRoleWithWebIdentityError)
creds, err := p.GetCredentials(context.TODO(), testcase.request, testcase.args)
if err != nil {
if testcase.expectedError == nil {
t.Fatalf("got unexpected error %s", err.Error())

}
if err == nil && testcase.expectedError != nil {
t.Fatalf("expected error '%s' but got no error", testcase.expectedError)
}
if err != nil && testcase.expectedError == nil {
t.Fatalf("got unexpected error %s", err.Error())
}

if testcase.expectedError.Error() != err.Error() {
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
}
if err != nil && testcase.expectedError.Error() != err.Error() {
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
}

if testcase.expectedError == nil {
Expand Down Expand Up @@ -382,7 +396,6 @@ func Test_GetCredentials_Public(t *testing.T) {
mockECRPublic.On("GetAuthorizationToken", mock.Anything, mock.Anything).Return(testcase.getAuthorizationTokenOutput, testcase.getAuthorizationTokenError)

creds, err := p.GetCredentials(context.TODO(), &v1.CredentialProviderRequest{Image: testcase.image}, testcase.args)

if testcase.expectedError != nil && (testcase.expectedError.Error() != err.Error()) {
t.Fatalf("expected %s, got %s", testcase.expectedError.Error(), err.Error())
}
Expand Down