@@ -10,45 +10,47 @@ import (
1010
1111 "github.com/aws/aws-sdk-go-v2/aws"
1212 "github.com/aws/aws-sdk-go-v2/aws/arn"
13- "github.com/aws/aws-sdk-go-v2/config"
14- awsCreds "github.com/aws/aws-sdk-go-v2/credentials"
15- "github.com/aws/aws-sdk-go-v2/service/eks"
1613 "github.com/aws/aws-sdk-go-v2/service/sts"
1714 "github.com/aws/aws-sdk-go-v2/service/sts/types"
1815 "github.com/golang-jwt/jwt/v5"
1916 "github.com/sirupsen/logrus"
2017 "go.amzn.com/eks/eks-pod-identity-agent/internal/middleware/logger"
2118 "go.amzn.com/eks/eks-pod-identity-agent/pkg/credentials"
19+ "go.amzn.com/eks/eks-pod-identity-agent/pkg/extensions/chainrole/ekspodidentities"
20+ "go.amzn.com/eks/eks-pod-identity-agent/pkg/extensions/chainrole/serviceaccount"
2221)
2322
2423const (
2524 assumeRoleAnnotationPrefix = "assume-role.ekspia.go.amzn.com/"
2625 sessionTagRoleAnnotationPrefix = assumeRoleAnnotationPrefix + "session-tag/"
26+ // service account annotations doesn't support more than one "/"
27+ sessionTagRoleAnnotationPrefix2 = assumeRoleAnnotationPrefix + "session-tag-"
2728)
2829
2930type (
3031 roleAssumer interface {
3132 AssumeRole (ctx context.Context , params * sts.AssumeRoleInput , optFns ... func (* sts.Options )) (* sts.AssumeRoleOutput , error )
3233 }
3334
34- sessionConfigFunc func (ctx context.Context , awsCfg aws.Config , clusterName string , associationID string ) (* sts.AssumeRoleInput , error )
35+ sessionConfigRetriever interface {
36+ GetSessionConfigMap (ctx context.Context , request * credentials.EksCredentialsRequest ) (map [string ]string , error )
37+ }
3538
3639 CredentialRetriever struct {
3740 delegate credentials.CredentialRetriever
3841 jwtParser * jwt.Parser
3942 roleAssumer roleAssumer
40- getSessionConfig sessionConfigFunc
43+ sessionConfigRetriever sessionConfigRetriever
4144 reNamespaceFilter * regexp.Regexp
4245 reServiceAccountFilter * regexp.Regexp
4346 }
4447)
4548
4649func NewCredentialsRetriever (awsCfg aws.Config , eksCredentialsRetriever credentials.CredentialRetriever ) * CredentialRetriever {
4750 cr := & CredentialRetriever {
48- delegate : eksCredentialsRetriever ,
49- jwtParser : jwt .NewParser (),
50- roleAssumer : sts .NewFromConfig (awsCfg ),
51- getSessionConfig : getSessionConfigurationFromEKSPodIdentityTags ,
51+ delegate : eksCredentialsRetriever ,
52+ jwtParser : jwt .NewParser (),
53+ roleAssumer : sts .NewFromConfig (awsCfg ),
5254 }
5355
5456 log := logger .FromContext (context .TODO ()).WithField ("extension" , "chainrole" )
@@ -69,75 +71,46 @@ func NewCredentialsRetriever(awsCfg aws.Config, eksCredentialsRetriever credenti
6971 log .Info ("Enabled extension..." )
7072 }
7173
72- return cr
73- }
74-
75- func getSessionConfigurationFromEKSPodIdentityTags (ctx context.Context , awsCfg aws.Config , clusterName , associationID string ) (* sts.AssumeRoleInput , error ) {
76- // Describe pod identity association to get tags
77- podIdentityAssociation , err := eks .NewFromConfig (awsCfg ).DescribePodIdentityAssociation (ctx ,
78- & eks.DescribePodIdentityAssociationInput {
79- AssociationId : aws .String (associationID ),
80- ClusterName : aws .String (clusterName ),
81- })
82- if err != nil {
83- return nil , fmt .Errorf ("error describing pod identity association %s/%s: %w" , clusterName , associationID , err )
74+ switch sessionConfigSourceVal {
75+ case eksPodIdentityAssociationTags :
76+ cr .sessionConfigRetriever = ekspodidentities .NewSessionConfigRetriever (eksCredentialsRetriever )
77+ case serviceAccountAnnotations :
78+ cr .sessionConfigRetriever = serviceaccount .NewSessionConfigRetriever ()
79+ default :
8480 }
8581
86- assumeRoleInput := tagsToSTSAssumeRole (podIdentityAssociation .Association .Tags )
87-
88- if assumeRoleInput .RoleArn == nil {
89- return nil , fmt .Errorf ("couldn't get assume role arn from pod identity association tags %v" , podIdentityAssociation .Association .Tags )
90- }
91-
92- return assumeRoleInput , nil
82+ return cr
9383}
9484
9585func (c * CredentialRetriever ) GetIamCredentials (ctx context.Context , request * credentials.EksCredentialsRequest ) (
9686 * credentials.EksCredentialsResponse , credentials.ResponseMetadata , error ) {
9787 log := logger .FromContext (ctx ).WithField ("extension" , "chainrole" )
9888
99- // Get AWS EKS Pod Identity credentials as usual
100- iamCredentials , responseMetadata , err := c .delegate .GetIamCredentials (ctx , request )
101- if err != nil {
102- return nil , nil , err
103- }
104-
10589 // Get Namespace and ServiceAccount names from JWT token
10690 ns , sa , err := c .serviceAccountFromJWT (request .ServiceAccountToken )
10791 if err != nil {
10892 return nil , nil , fmt .Errorf ("error parsing JWT token: %w" , err )
10993 }
11094
111- log = log .WithFields (logrus.Fields {
112- "namespace" : ns ,
113- "serviceaccount" : sa ,
114- "cluster-name" : request .ClusterName ,
115- "association-id" : responseMetadata .AssociationId (),
116- })
117-
11895 // Check if Namespace/ServiceAccount filters configured
11996 // and do not proceed with role chaining if they don't match
12097 if ! c .isEnabledFor (ns , sa ) {
12198 log .Debug ("namespace/serviceaccount do not match ChainRole filter. Skipping role chaining" )
122- return iamCredentials , responseMetadata , nil
99+ return c . delegate . GetIamCredentials ( ctx , request )
123100 }
124101
125- // Assume eks pod identity credentials
126- podIdentityCfg , err := config .LoadDefaultConfig (context .TODO (), config .WithCredentialsProvider (
127- awsCreds .NewStaticCredentialsProvider (iamCredentials .AccessKeyId , iamCredentials .SecretAccessKey , iamCredentials .Token ),
128- ))
129- if err != nil {
130- return nil , nil , fmt .Errorf ("error loading pod identity credentials: %w" , err )
131- }
102+ log = log .WithFields (logrus.Fields {
103+ "namespace" : ns ,
104+ "serviceaccount" : sa ,
105+ "cluster-name" : request .ClusterName ,
106+ })
132107
133- // Assume new session based on the configurations provided in tags
134- // session is assumed based on the IRSA credentials and NOT EKS Identity credentials
135- // this is because EKS Identity credentials adds bunch of default tags
136- // leaving no space for our custom tags https://github.com/aws/containers-roadmap/issues/2413
137- assumeRoleInput , err := c .getSessionConfig (ctx , podIdentityCfg , request .ClusterName , responseMetadata .AssociationId ())
108+ sessionConfigMap , err := c .sessionConfigRetriever .GetSessionConfigMap (ctx , request )
138109 if err != nil {
139- return nil , nil , fmt . Errorf ( "error getting session configuration: %w" , err )
110+ return nil , nil , err
140111 }
112+
113+ assumeRoleInput := tagsToSTSAssumeRole (sessionConfigMap )
141114 assumeRoleOutput , err := c .roleAssumer .AssumeRole (ctx , assumeRoleInput )
142115 if err != nil {
143116 return nil , nil , fmt .Errorf ("error assuming role %s: %w" , * assumeRoleInput .RoleArn , err )
@@ -154,7 +127,7 @@ func (c *CredentialRetriever) GetIamCredentials(ctx context.Context, request *cr
154127 return nil , nil , fmt .Errorf ("error formatting IAM credentials: %w" , err )
155128 }
156129
157- return assumedCredentials , responseMetadata , nil
130+ return assumedCredentials , nil , nil
158131}
159132
160133func (c * CredentialRetriever ) isEnabledFor (namespace , serviceAccount string ) bool {
@@ -196,8 +169,9 @@ func tagsToSTSAssumeRole(tags map[string]string) *sts.AssumeRoleInput {
196169 assumeRoleParams .DurationSeconds = aws .Int32 (int32 (duration .Seconds ()))
197170 }
198171
199- if strings .HasPrefix (key , sessionTagRoleAnnotationPrefix ) {
172+ if strings .HasPrefix (key , sessionTagRoleAnnotationPrefix ) || strings . HasPrefix ( key , sessionTagRoleAnnotationPrefix2 ) {
200173 tagKey := strings .TrimPrefix (key , sessionTagRoleAnnotationPrefix )
174+ tagKey = strings .TrimPrefix (tagKey , sessionTagRoleAnnotationPrefix2 )
201175
202176 assumeRoleParams .Tags = append (assumeRoleParams .Tags , types.Tag {
203177 Key : aws .String (tagKey ),
@@ -228,26 +202,15 @@ func formatIAMCredentials(o *sts.AssumeRoleOutput) (*credentials.EksCredentialsR
228202 }, nil
229203}
230204
231- func (c * CredentialRetriever ) serviceAccountFromJWT (token string ) (string , string , error ) {
232- parsedToken , _ , err := c . jwtParser . ParseUnverified (token , & jwt. RegisteredClaims {} )
205+ func (c * CredentialRetriever ) serviceAccountFromJWT (token string ) (ns string , sa string , err error ) {
206+ claims , subject , err := serviceaccount . ServiceAccountFromJWT (token )
233207 if err != nil {
234208 return "" , "" , fmt .Errorf ("error parsing JWT token: %w" , err )
235209 }
236210
237- subject , err := parsedToken .Claims .GetSubject ()
238- if err != nil {
239- return "" , "" , fmt .Errorf ("error reading JWT token subject: %w" , err )
240- }
241-
242- // subject is in the format: system:serviceaccount:<namespace>:<service_account>
243- if ! strings .HasPrefix (subject , "system:serviceaccount:" ) {
244- return "" , "" , errors .New ("JWT token claim subject doesn't start with 'system:serviceaccount:'" )
245- }
246-
247- subjectParts := strings .Split (subject , ":" )
248- if len (subjectParts ) < 4 {
249- return "" , "" , errors .New ("invalid JWT token claim subject" )
211+ if claims != nil && claims .Namespace != "" && claims .ServiceAccount .Name != "" {
212+ return claims .Namespace , claims .ServiceAccount .Name , nil
250213 }
251214
252- return subjectParts [ 2 ], subjectParts [ 3 ], nil
215+ return serviceaccount . ServiceAccountFromJWTSubject ( subject )
253216}
0 commit comments