Skip to content

Commit 8dd1286

Browse files
committed
Add optional GetFederatedToken
1 parent a4a4307 commit 8dd1286

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

sdk/azidentity/workload_identity.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ const credNameWorkloadIdentity = "WorkloadIdentityCredential"
2323
//
2424
// [Azure Kubernetes Service documentation]: https://learn.microsoft.com/azure/aks/workload-identity-overview
2525
type WorkloadIdentityCredential struct {
26-
assertion, file string
27-
cred *ClientAssertionCredential
28-
expires time.Time
29-
mtx *sync.RWMutex
26+
assertion, file string
27+
getAssertionOverride func(ctx context.Context) (string, error)
28+
cred *ClientAssertionCredential
29+
expires time.Time
30+
mtx *sync.RWMutex
3031
}
3132

3233
type WorkloadIdentityAzureProxyOptions = customtokenproxy.Options
@@ -73,7 +74,12 @@ type WorkloadIdentityCredentialOptions struct {
7374

7475
// TokenFilePath is the path of a file containing a Kubernetes service account token. Defaults to the value of the
7576
// environment variable AZURE_FEDERATED_TOKEN_FILE.
77+
// If GetFederatedToken is set, this field is ignored.
7678
TokenFilePath string
79+
80+
// GetFederatedToken defines an optional func to get the Kubernetes service account token.
81+
// If this function is set, it will be used to get the token instead of reading it from TokenFilePath.
82+
GetFederatedToken func(ctx context.Context) (string, error)
7783
}
7884

7985
// NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read
@@ -90,9 +96,9 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (
9096
}
9197
}
9298
file := options.TokenFilePath
93-
if file == "" {
99+
if file == "" && options.GetFederatedToken == nil {
94100
if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok {
95-
return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options")
101+
return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath or GetFederatedToken in the options")
96102
}
97103
}
98104
tenantID := options.TenantID
@@ -102,7 +108,11 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (
102108
}
103109
}
104110

105-
w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
111+
w := WorkloadIdentityCredential{
112+
file: file,
113+
getAssertionOverride: options.GetFederatedToken,
114+
mtx: &sync.RWMutex{},
115+
}
106116
caco := &ClientAssertionCredentialOptions{
107117
AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
108118
Cache: options.Cache,
@@ -139,7 +149,11 @@ func (w *WorkloadIdentityCredential) GetToken(ctx context.Context, opts policy.T
139149

140150
// getAssertion returns the specified file's content, which is expected to be a Kubernetes service account token.
141151
// Kubernetes is responsible for updating the file as service account tokens expire.
142-
func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, error) {
152+
func (w *WorkloadIdentityCredential) getAssertion(ctx context.Context) (string, error) {
153+
if w.getAssertionOverride != nil {
154+
return w.getAssertionOverride(ctx)
155+
}
156+
143157
w.mtx.RLock()
144158
if w.expires.Before(time.Now()) {
145159
// ensure only one goroutine at a time updates the assertion

sdk/azidentity/workload_identity_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,34 @@ func (c *customTokenRequestPolicyFlowCheck) Validate(t testing.TB, req *http.Req
388388
require.Equal(t, c.requiredHeaderValue, req.Header.Get(c.requiredHeaderKey))
389389
}
390390

391+
func TestWorkloadIdentityCredential_GetFederatedTokenOverride(t *testing.T) {
392+
const overrideToken = "override-token"
393+
called := new(atomic.Int32)
394+
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
395+
ClientID: fakeClientID,
396+
TenantID: fakeTenantID,
397+
ClientOptions: policy.ClientOptions{
398+
Transport: &mockSTS{
399+
tokenRequestCallback: func(req *http.Request) *http.Response {
400+
require.NoError(t, req.ParseForm())
401+
require.Equal(t, overrideToken, req.PostForm.Get("client_assertion"))
402+
called.Add(1)
403+
return nil
404+
},
405+
},
406+
},
407+
GetFederatedToken: func(ctx context.Context) (string, error) {
408+
return overrideToken, nil
409+
},
410+
})
411+
require.NoError(t, err)
412+
413+
tk, err := cred.GetToken(context.Background(), testTRO)
414+
require.NoError(t, err)
415+
require.Equal(t, tokenValue, tk.Token)
416+
require.Equal(t, int32(1), called.Load())
417+
}
418+
391419
func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T) {
392420
tempFile := filepath.Join(t.TempDir(), "test-workload-token-file")
393421
if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil {

0 commit comments

Comments
 (0)