From a4a43075897d655f68128986e2061be3e657a34a Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 25 Nov 2025 09:35:04 +1100 Subject: [PATCH 1/8] Allow passing token proxy options from options --- sdk/azidentity/go.work.sum | 8 + .../customtokenproxy/configuration.go | 139 ++++++++ .../customtokenproxy/configuration_test.go | 300 ++++++++++++++++++ .../internal/customtokenproxy/transport.go | 86 ----- .../customtokenproxy/transport_test.go | 229 ------------- sdk/azidentity/workload_identity.go | 15 +- sdk/azidentity/workload_identity_test.go | 52 +++ 7 files changed, 511 insertions(+), 318 deletions(-) create mode 100644 sdk/azidentity/go.work.sum create mode 100644 sdk/azidentity/internal/customtokenproxy/configuration.go create mode 100644 sdk/azidentity/internal/customtokenproxy/configuration_test.go diff --git a/sdk/azidentity/go.work.sum b/sdk/azidentity/go.work.sum new file mode 100644 index 000000000000..7a078fbab348 --- /dev/null +++ b/sdk/azidentity/go.work.sum @@ -0,0 +1,8 @@ +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/keybase/dbus v0.0.0-20220506165403-5aa21ea2c23a/go.mod h1:YPNKjjE7Ubp9dTbnWvsP3HT+hYnY6TfXzubYTBeUxc8= +github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= diff --git a/sdk/azidentity/internal/customtokenproxy/configuration.go b/sdk/azidentity/internal/customtokenproxy/configuration.go new file mode 100644 index 000000000000..c638671b0458 --- /dev/null +++ b/sdk/azidentity/internal/customtokenproxy/configuration.go @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package customtokenproxy + +import ( + "errors" + "fmt" + "net/url" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +const ( + AzureKubernetesCAData = "AZURE_KUBERNETES_CA_DATA" + AzureKubernetesCAFile = "AZURE_KUBERNETES_CA_FILE" + AzureKubernetesSNIName = "AZURE_KUBERNETES_SNI_NAME" + + AzureKubernetesTokenProxy = "AZURE_KUBERNETES_TOKEN_PROXY" +) + +// Options contains optional parameters for custom token proxy configuration. +type Options struct { + // AzureKubernetesCAData specifies the CA certificate data for the Kubernetes cluster. + // Corresponds to the AZURE_KUBERNETES_CA_DATA environment variable. + // At most one of AzureKubernetesCAData or AzureKubernetesCAFile should be set. + AzureKubernetesCAData string + + // AzureKubernetesCAFile specifies the path to the CA certificate file for the Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_CA_FILE environment variable. + // At most one of AzureKubernetesCAData or AzureKubernetesCAFile should be set. + AzureKubernetesCAFile string + + // AzureKubernetesSNIName specifies the name of the SNI for Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_SNI_NAME environment variable. + AzureKubernetesSNIName string + + // AzureKubernetesTokenProxy specifies the URL of the custom token proxy for the Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_TOKEN_PROXY environment variable. + AzureKubernetesTokenProxy string +} + +func parseTokenProxyURL(endpoint string) (*url.URL, error) { + tokenProxy, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("failed to parse custom token proxy URL %q: %s", endpoint, err) + } + if tokenProxy.Scheme != "https" { + return nil, fmt.Errorf("custom token endpoint must use https scheme, got %q", tokenProxy.Scheme) + } + if tokenProxy.User != nil { + return nil, fmt.Errorf("custom token endpoint URL %q must not contain user info", tokenProxy) + } + if tokenProxy.RawQuery != "" { + return nil, fmt.Errorf("custom token endpoint URL %q must not contain a query", tokenProxy) + } + if tokenProxy.EscapedFragment() != "" { + return nil, fmt.Errorf("custom token endpoint URL %q must not contain a fragment", tokenProxy) + } + if tokenProxy.EscapedPath() == "" { + // if the path is empty, set it to "/" to avoid stripping the path from req.URL + tokenProxy.Path = "/" + } + return tokenProxy, nil +} + +func (o *Options) defaults() { + if o.AzureKubernetesTokenProxy == "" { + o.AzureKubernetesTokenProxy = os.Getenv(AzureKubernetesTokenProxy) + } + if o.AzureKubernetesSNIName == "" { + o.AzureKubernetesSNIName = os.Getenv(AzureKubernetesSNIName) + } + if o.AzureKubernetesCAFile == "" { + o.AzureKubernetesCAFile = os.Getenv(AzureKubernetesCAFile) + } + if o.AzureKubernetesCAData == "" { + o.AzureKubernetesCAData = os.Getenv(AzureKubernetesCAData) + } +} + +var ( + errCustomEndpointSetWithoutTokenProxy = errors.New( + "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related settings are present", + ) + errCustomEndpointMultipleCASourcesSet = errors.New( + "only one of AzureKubernetesCAFile or AzureKubernetesCAData can be specified", + ) +) + +func noopConfigure(*policy.ClientOptions) { + // no-op +} + +// Apply returns a function that configures the client options to use the custom token proxy. +func Apply(opts *Options) (func(*policy.ClientOptions), error) { + if opts == nil { + return noopConfigure, nil + } + + opts.defaults() + + if opts.AzureKubernetesTokenProxy == "" { + // custom token proxy is not set, while other Kubernetes-related environment variables are present, + // this is likely a configuration issue so erroring out to avoid misconfiguration + if opts.AzureKubernetesSNIName != "" || opts.AzureKubernetesCAFile != "" || opts.AzureKubernetesCAData != "" { + return nil, errCustomEndpointSetWithoutTokenProxy + } + + return noopConfigure, nil + } + + tokenProxy, err := parseTokenProxyURL(opts.AzureKubernetesTokenProxy) + if err != nil { + return nil, err + } + + // CAFile and CAData are mutually exclusive, at most one can be set. + // If none of CAFile or CAData are set, the default system CA pool will be used. + if opts.AzureKubernetesCAFile != "" && opts.AzureKubernetesCAData != "" { + return nil, errCustomEndpointMultipleCASourcesSet + } + + // preload the transport + t := &transport{ + caFile: opts.AzureKubernetesCAFile, + caData: []byte(opts.AzureKubernetesCAData), + sniName: opts.AzureKubernetesSNIName, + tokenProxy: tokenProxy, + } + if _, err := t.getTokenTransporter(); err != nil { + return nil, err + } + + return func(clientOptions *policy.ClientOptions) { + clientOptions.Transport = t + }, nil +} diff --git a/sdk/azidentity/internal/customtokenproxy/configuration_test.go b/sdk/azidentity/internal/customtokenproxy/configuration_test.go new file mode 100644 index 000000000000..ee44c9c15caf --- /dev/null +++ b/sdk/azidentity/internal/customtokenproxy/configuration_test.go @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package customtokenproxy + +import ( + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +func TestParseTokenProxyURL(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + endpoint string + check func(t testing.TB, u *url.URL, err error) + }{ + { + name: "valid https endpoint without path", + endpoint: "https://example.com", + check: func(t testing.TB, u *url.URL, err error) { + require.NoError(t, err) + require.Equal(t, "https", u.Scheme) + require.Equal(t, "example.com", u.Host) + require.Equal(t, "", u.RawQuery) + require.Equal(t, "", u.Fragment) + require.Equal(t, "/", u.Path, "should set path to '/' if not present") + }, + }, + { + name: "valid https endpoint with path", + endpoint: "https://example.com/token/path", + check: func(t testing.TB, u *url.URL, err error) { + require.NoError(t, err) + require.Equal(t, "/token/path", u.Path) + }, + }, + { + name: "reject non-https scheme", + endpoint: "http://example.com", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "https scheme") + }, + }, + { + name: "reject user info", + endpoint: "https://user:pass@example.com/token", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "must not contain user info") + }, + }, + { + name: "reject query params", + endpoint: "https://example.com/token?foo=bar", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "must not contain a query") + }, + }, + { + name: "reject fragment", + endpoint: "https://example.com/token#frag", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "must not contain a fragment") + }, + }, + { + name: "reject unparseable URL", + endpoint: "https://example.com/%zz", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "failed to parse custom token proxy URL") + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + u, err := parseTokenProxyURL(c.endpoint) + c.check(t, u, err) + }) + } +} + +func TestOptions_Configure(t *testing.T) { + var ( + testCAData = string(createTestCA(t)) + testCAFile = createTestCAFile(t) + ) + + tests := []struct { + Name string + Envs map[string]string + Options Options + ClientOptions policy.ClientOptions + ExpectErr bool + AssertErr func(t testing.TB, err error) + ExpectTransport bool + }{ + { + Name: "no custom endpoint", + Envs: map[string]string{}, + Options: Options{}, + ExpectErr: false, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with minimal settings", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint enabled with CA file + SNI", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAFile: testCAFile, + AzureKubernetesSNIName: "custom-sni.example.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint enabled with invalid CA file", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAFile: "/non/existent/path/to/custom-ca-file.pem", + }, + ExpectErr: true, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with CA file contains invalid CA data", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAFile: func() string { + t.Helper() + + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "invalid-ca-file.pem") + require.NoError(t, os.WriteFile(caFile, []byte("invalid-ca-cert"), 0600)) + return caFile + }(), + }, + ExpectErr: true, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with CA data + SNI", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAData: testCAData, + AzureKubernetesSNIName: "custom-sni.example.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint enabled with invalid CA data", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAData: string("invalid-ca-cert"), + }, + ExpectErr: true, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with SNI", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesSNIName: "custom-sni.example.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint disabled with extra environment variables", + Options: Options{ + AzureKubernetesSNIName: "custom-sni.example.com", + }, + ExpectErr: true, + AssertErr: func(t testing.TB, err error) { + require.ErrorIs(t, err, errCustomEndpointSetWithoutTokenProxy) + }, + }, + { + Name: "custom endpoint enabled with both CAData and CAFile", + Options: Options{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAData: testCAData, + AzureKubernetesCAFile: testCAFile, + }, + ExpectErr: true, + AssertErr: func(t testing.TB, err error) { + require.ErrorIs(t, err, errCustomEndpointMultipleCASourcesSet) + }, + }, + { + Name: "custom endpoint enabled with invalid endpoint", + Options: Options{ + // http endpoint is not allowed + AzureKubernetesTokenProxy: "http://custom-endpoint.com", + }, + ExpectErr: true, + }, + { + Name: "set by environment variables", + Envs: map[string]string{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesCAFile: testCAFile, + AzureKubernetesSNIName: "custom-sni.example.com", + }, + Options: Options{}, + ExpectErr: false, + ExpectTransport: true, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + if len(tt.Envs) > 0 { + for k, v := range tt.Envs { + t.Setenv(k, v) + } + } + + mutateClientOptions, err := Apply(&tt.Options) + if tt.ExpectErr { + require.Error(t, err) + if tt.AssertErr != nil { + tt.AssertErr(t, err) + } + return + } + + require.NoError(t, err) + + mutateClientOptions(&tt.ClientOptions) + if tt.ExpectTransport { + require.NotNil(t, tt.ClientOptions.Transport) + require.IsType(t, &transport{}, tt.ClientOptions.Transport) + } else { + require.Nil(t, tt.ClientOptions.Transport) + } + }) + } +} + +func TestConfiguration_defaults(t *testing.T) { + t.Run("fills from env when empty", func(t *testing.T) { + expected := map[string]string{ + AzureKubernetesTokenProxy: "https://custom-endpoint.com", + AzureKubernetesSNIName: "sni.example.com", + AzureKubernetesCAFile: "/path/to/ca.pem", + AzureKubernetesCAData: "pem-data", + } + for k, v := range expected { + t.Setenv(k, v) + } + + opts := Options{} + opts.defaults() + + require.Equal(t, expected[AzureKubernetesTokenProxy], opts.AzureKubernetesTokenProxy) + require.Equal(t, expected[AzureKubernetesSNIName], opts.AzureKubernetesSNIName) + require.Equal(t, expected[AzureKubernetesCAFile], opts.AzureKubernetesCAFile) + require.Equal(t, expected[AzureKubernetesCAData], opts.AzureKubernetesCAData) + }) + + t.Run("preserves explicit values", func(t *testing.T) { + t.Setenv(AzureKubernetesTokenProxy, "https://env-value.com") + opts := Options{ + AzureKubernetesTokenProxy: "https://explicit.com", + AzureKubernetesSNIName: "explicit-sni", + AzureKubernetesCAFile: "/explicit/ca.pem", + AzureKubernetesCAData: "explicit-ca-data", + } + + opts.defaults() + + require.Equal(t, "https://explicit.com", opts.AzureKubernetesTokenProxy) + require.Equal(t, "explicit-sni", opts.AzureKubernetesSNIName) + require.Equal(t, "/explicit/ca.pem", opts.AzureKubernetesCAFile) + require.Equal(t, "explicit-ca-data", opts.AzureKubernetesCAData) + }) +} diff --git a/sdk/azidentity/internal/customtokenproxy/transport.go b/sdk/azidentity/internal/customtokenproxy/transport.go index 6c0fc6244203..db4fe93905d0 100644 --- a/sdk/azidentity/internal/customtokenproxy/transport.go +++ b/sdk/azidentity/internal/customtokenproxy/transport.go @@ -13,49 +13,6 @@ import ( "net/url" "os" "time" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" -) - -const ( - AzureKubernetesCAData = "AZURE_KUBERNETES_CA_DATA" - AzureKubernetesCAFile = "AZURE_KUBERNETES_CA_FILE" - AzureKubernetesSNIName = "AZURE_KUBERNETES_SNI_NAME" - - AzureKubernetesTokenProxy = "AZURE_KUBERNETES_TOKEN_PROXY" -) - -func parseAndValidate(endpoint string) (*url.URL, error) { - tokenProxy, err := url.Parse(endpoint) - if err != nil { - return nil, fmt.Errorf("failed to parse custom token proxy URL %q: %s", endpoint, err) - } - if tokenProxy.Scheme != "https" { - return nil, fmt.Errorf("custom token endpoint must use https scheme, got %q", tokenProxy.Scheme) - } - if tokenProxy.User != nil { - return nil, fmt.Errorf("custom token endpoint URL %q must not contain user info", tokenProxy) - } - if tokenProxy.RawQuery != "" { - return nil, fmt.Errorf("custom token endpoint URL %q must not contain a query", tokenProxy) - } - if tokenProxy.EscapedFragment() != "" { - return nil, fmt.Errorf("custom token endpoint URL %q must not contain a fragment", tokenProxy) - } - if tokenProxy.EscapedPath() == "" { - // if the path is empty, set it to "/" to avoid stripping the path from req.URL - tokenProxy.Path = "/" - } - return tokenProxy, nil -} - -var ( - errCustomEndpointEnvSetWithoutTokenProxy = errors.New( - "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present", - ) - errCustomEndpointMultipleCASourcesSet = errors.New( - "only one of AZURE_KUBERNETES_CA_FILE and AZURE_KUBERNETES_CA_DATA can be specified", - ) ) func createTransport(sniName string, caPool *x509.CertPool) *http.Transport { @@ -82,49 +39,6 @@ func createTransport(sniName string, caPool *x509.CertPool) *http.Transport { return transport } -// Configure configures custom token endpoint mode if the required environment variables are present. -func Configure(clientOptions *policy.ClientOptions) error { - kubernetesTokenProxyStr := os.Getenv(AzureKubernetesTokenProxy) - - kubernetesSNIName := os.Getenv(AzureKubernetesSNIName) - kubernetesCAFile := os.Getenv(AzureKubernetesCAFile) - kubernetesCAData := os.Getenv(AzureKubernetesCAData) - - if kubernetesTokenProxyStr == "" { - // custom token proxy is not set, while other Kubernetes-related environment variables are present, - // this is likely a configuration issue so erroring out to avoid misconfiguration - if kubernetesSNIName != "" || kubernetesCAFile != "" || kubernetesCAData != "" { - return errCustomEndpointEnvSetWithoutTokenProxy - } - - return nil - } - tokenProxy, err := parseAndValidate(kubernetesTokenProxyStr) - if err != nil { - return err - } - - // CAFile and CAData are mutually exclusive, at most one can be set. - // If none of CAFile or CAData are set, the default system CA pool will be used. - if kubernetesCAFile != "" && kubernetesCAData != "" { - return errCustomEndpointMultipleCASourcesSet - } - - // preload the transport - t := &transport{ - caFile: kubernetesCAFile, - caData: []byte(kubernetesCAData), - sniName: kubernetesSNIName, - tokenProxy: tokenProxy, - } - if _, err := t.getTokenTransporter(); err != nil { - return err - } - - clientOptions.Transport = t - return nil -} - // transport redirects requests to the configured proxy. // // Lock is not needed for internal caData as this transport is called under confidentialClient's lock. diff --git a/sdk/azidentity/internal/customtokenproxy/transport_test.go b/sdk/azidentity/internal/customtokenproxy/transport_test.go index c74a11a9169b..97834a4711de 100644 --- a/sdk/azidentity/internal/customtokenproxy/transport_test.go +++ b/sdk/azidentity/internal/customtokenproxy/transport_test.go @@ -19,237 +19,9 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/stretchr/testify/require" ) -func TestParseAndValidate(t *testing.T) { - cases := []struct { - name string - endpoint string - check func(t testing.TB, u *url.URL, err error) - }{ - { - name: "valid https endpoint without path", - endpoint: "https://example.com", - check: func(t testing.TB, u *url.URL, err error) { - require.NoError(t, err) - require.Equal(t, "https", u.Scheme) - require.Equal(t, "example.com", u.Host) - require.Equal(t, "", u.RawQuery) - require.Equal(t, "", u.Fragment) - require.Equal(t, "/", u.Path, "should set path to '/' if not present") - }, - }, - { - name: "valid https endpoint with path", - endpoint: "https://example.com/token/path", - check: func(t testing.TB, u *url.URL, err error) { - require.NoError(t, err) - require.Equal(t, "/token/path", u.Path) - }, - }, - { - name: "reject non-https scheme", - endpoint: "http://example.com", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "https scheme") - }, - }, - { - name: "reject user info", - endpoint: "https://user:pass@example.com/token", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "must not contain user info") - }, - }, - { - name: "reject query params", - endpoint: "https://example.com/token?foo=bar", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "must not contain a query") - }, - }, - { - name: "reject fragment", - endpoint: "https://example.com/token#frag", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "must not contain a fragment") - }, - }, - { - name: "reject unparseable URL", - endpoint: "https://example.com/%zz", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "failed to parse custom token proxy URL") - }, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - u, err := parseAndValidate(c.endpoint) - c.check(t, u, err) - }) - } -} - -func TestConfigure(t *testing.T) { - var ( - testCAData = string(createTestCA(t)) - testCAFile = createTestCAFile(t) - ) - - cases := []struct { - name string - envs map[string]string - clientOptions policy.ClientOptions - - expectErr bool - checkErr func(t testing.TB, err error) // optional check on error - expectTransport bool - }{ - { - name: "no custom endpoint", - expectErr: false, - expectTransport: false, - }, - { - name: "custom endpoint enabled with minimal settings", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint enabled with CA file + SNI", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: testCAFile, - AzureKubernetesSNIName: "custom-sni.example.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint enabled with invalid CA file", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: "/non/existent/path/to/custom-ca-file.pem", - }, - expectTransport: false, - }, - { - name: "custom endpoint enabled with CA file contains invalid CA data", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: func() string { - t.Helper() - - tempDir := t.TempDir() - caFile := filepath.Join(tempDir, "invalid-ca-file.pem") - require.NoError(t, os.WriteFile(caFile, []byte("invalid-ca-cert"), 0600)) - return caFile - }(), - }, - expectTransport: false, - }, - { - name: "custom endpoint enabled with CA data + SNI", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: testCAData, - AzureKubernetesSNIName: "custom-sni.example.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint enabled with invalid CA data", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: string("invalid-ca-cert"), - }, - expectTransport: false, - }, - { - name: "custom endpoint enabled with SNI", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesSNIName: "custom-sni.example.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint disabled with extra environment variables", - expectErr: true, - envs: map[string]string{ - AzureKubernetesSNIName: "custom-sni.example.com", - }, - checkErr: func(t testing.TB, err error) { - require.ErrorIs(t, err, errCustomEndpointEnvSetWithoutTokenProxy) - }, - }, - { - name: "custom endpoint enabled with both CAData and CAFile", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: testCAData, - AzureKubernetesCAFile: testCAFile, - }, - checkErr: func(t testing.TB, err error) { - require.ErrorIs(t, err, errCustomEndpointMultipleCASourcesSet) - }, - }, - { - name: "custom endpoint enabled with invalid endpoint", - expectErr: true, - envs: map[string]string{ - // http endpoint is not allowed - AzureKubernetesTokenProxy: "http://custom-endpoint.com", - }, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - if len(c.envs) > 0 { - for k, v := range c.envs { - t.Setenv(k, v) - } - } - err := Configure(&c.clientOptions) - - if c.expectErr { - require.Error(t, err) - if c.checkErr != nil { - c.checkErr(t, err) - } - return - } - - require.NoError(t, err) - if c.expectTransport { - require.NotNil(t, c.clientOptions.Transport) - require.IsType(t, &transport{}, c.clientOptions.Transport) - } else { - require.Nil(t, c.clientOptions.Transport) - } - }) - } -} - // createTestCA creates a valid CA as bytes func createTestCA(t testing.TB) []byte { t.Helper() @@ -439,7 +211,6 @@ func TestGetTokenTransporter_reentry(t *testing.T) { func TestTransport_Do(t *testing.T) { mux := http.NewServeMux() testServer := httptest.NewTLSServer(mux) - ca := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: testServer.Certificate().Raw}) require.NotEmpty(t, ca) diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index 11b43e0904f3..c07e897fa092 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -29,6 +29,8 @@ type WorkloadIdentityCredential struct { mtx *sync.RWMutex } +type WorkloadIdentityAzureProxyOptions = customtokenproxy.Options + // WorkloadIdentityCredentialOptions contains optional parameters for WorkloadIdentityCredential. type WorkloadIdentityCredentialOptions struct { azcore.ClientOptions @@ -52,8 +54,9 @@ type WorkloadIdentityCredentialOptions struct { // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool - // EnableAzureProxy determines whether the credential reads proxy configuration from environment variables. When - // this value is true and proxy configuration isn't present or this value is false, the credential will request + // EnableAzureProxy determines whether the credential reads proxy configuration from environment variables or + // from the AzureTokenProxyOptions field. + // When this value is true and proxy configuration isn't present or this value is false, the credential will request // tokens directly from Entra ID. // // The proxy feature is designed for applications that deploy to many clusters and clusters that host many @@ -61,6 +64,10 @@ type WorkloadIdentityCredentialOptions struct { // to set this option: https://learn.microsoft.com/azure/aks/identity-bindings-concepts EnableAzureProxy bool + // AzureProxy specifies the options for the Azure proxy. + // If EnableAzureProxy is false, this field is ignored. + AzureProxy WorkloadIdentityAzureProxyOptions + // TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID. TenantID string @@ -104,9 +111,11 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } if options.EnableAzureProxy { - if err := customtokenproxy.Configure(&caco.ClientOptions); err != nil { + mutateClientOptions, err := customtokenproxy.Apply(&options.AzureProxy) + if err != nil { return nil, err } + mutateClientOptions(&caco.ClientOptions) } cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, caco) diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index 85fc680e4be0..226d8b88caa8 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -388,6 +388,58 @@ func (c *customTokenRequestPolicyFlowCheck) Validate(t testing.TB, req *http.Req require.Equal(t, c.requiredHeaderValue, req.Header.Get(c.requiredHeaderKey)) } +func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") + if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() + + customTokenEndointServerCalledTimes := new(atomic.Int32) + customTokenEndointServer, caData := startTestTokenEndpointWithCAData( + t, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + customTokenEndointServerCalledTimes.Add(1) + + policyFlowCheck.Validate(t, req) + + require.NoError(t, req.ParseForm()) + require.NotEmpty(t, req.PostForm) + + require.Contains(t, req.PostForm, "client_assertion") + require.Equal(t, req.PostForm.Get("client_assertion"), testClientAssertion) + + require.Contains(t, req.PostForm, "client_id") + require.Equal(t, req.PostForm.Get("client_id"), fakeClientID) + + _, _ = w.Write(accessTokenRespSuccess) + }), + ) + + clientOptions := policy.ClientOptions{ + PerCallPolicies: []policy.Policy{ + policyFlowCheck.Policy(), + }, + } + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + ClientOptions: clientOptions, + EnableAzureProxy: true, + TenantID: fakeTenantID, + TokenFilePath: tempFile, + AzureProxy: WorkloadIdentityAzureProxyOptions{ + AzureKubernetesTokenProxy: customTokenEndointServer.URL, + AzureKubernetesCAData: caData, + }, + }) + require.NoError(t, err) + require.Nil(t, clientOptions.Transport, "constructor shouldn't mutate caller's ClientOptions") + + testGetTokenSuccess(t, cred) + + require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) +} + func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { From 8dd1286089913e01ba2fbe133b247af94df42012 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Fri, 21 Nov 2025 11:06:40 +1100 Subject: [PATCH 2/8] Add optional GetFederatedToken --- sdk/azidentity/workload_identity.go | 30 +++++++++++++++++------- sdk/azidentity/workload_identity_test.go | 28 ++++++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index c07e897fa092..5d5a525d43c6 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -23,10 +23,11 @@ const credNameWorkloadIdentity = "WorkloadIdentityCredential" // // [Azure Kubernetes Service documentation]: https://learn.microsoft.com/azure/aks/workload-identity-overview type WorkloadIdentityCredential struct { - assertion, file string - cred *ClientAssertionCredential - expires time.Time - mtx *sync.RWMutex + assertion, file string + getAssertionOverride func(ctx context.Context) (string, error) + cred *ClientAssertionCredential + expires time.Time + mtx *sync.RWMutex } type WorkloadIdentityAzureProxyOptions = customtokenproxy.Options @@ -73,7 +74,12 @@ type WorkloadIdentityCredentialOptions struct { // TokenFilePath is the path of a file containing a Kubernetes service account token. Defaults to the value of the // environment variable AZURE_FEDERATED_TOKEN_FILE. + // If GetFederatedToken is set, this field is ignored. TokenFilePath string + + // GetFederatedToken defines an optional func to get the Kubernetes service account token. + // If this function is set, it will be used to get the token instead of reading it from TokenFilePath. + GetFederatedToken func(ctx context.Context) (string, error) } // NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read @@ -90,9 +96,9 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } } file := options.TokenFilePath - if file == "" { + if file == "" && options.GetFederatedToken == nil { if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok { - return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options") + return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath or GetFederatedToken in the options") } } tenantID := options.TenantID @@ -102,7 +108,11 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } } - w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}} + w := WorkloadIdentityCredential{ + file: file, + getAssertionOverride: options.GetFederatedToken, + mtx: &sync.RWMutex{}, + } caco := &ClientAssertionCredentialOptions{ AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, Cache: options.Cache, @@ -139,7 +149,11 @@ func (w *WorkloadIdentityCredential) GetToken(ctx context.Context, opts policy.T // getAssertion returns the specified file's content, which is expected to be a Kubernetes service account token. // Kubernetes is responsible for updating the file as service account tokens expire. -func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, error) { +func (w *WorkloadIdentityCredential) getAssertion(ctx context.Context) (string, error) { + if w.getAssertionOverride != nil { + return w.getAssertionOverride(ctx) + } + w.mtx.RLock() if w.expires.Before(time.Now()) { // ensure only one goroutine at a time updates the assertion diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index 226d8b88caa8..44ea8267ddcd 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -388,6 +388,34 @@ func (c *customTokenRequestPolicyFlowCheck) Validate(t testing.TB, req *http.Req require.Equal(t, c.requiredHeaderValue, req.Header.Get(c.requiredHeaderKey)) } +func TestWorkloadIdentityCredential_GetFederatedTokenOverride(t *testing.T) { + const overrideToken = "override-token" + called := new(atomic.Int32) + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + TenantID: fakeTenantID, + ClientOptions: policy.ClientOptions{ + Transport: &mockSTS{ + tokenRequestCallback: func(req *http.Request) *http.Response { + require.NoError(t, req.ParseForm()) + require.Equal(t, overrideToken, req.PostForm.Get("client_assertion")) + called.Add(1) + return nil + }, + }, + }, + GetFederatedToken: func(ctx context.Context) (string, error) { + return overrideToken, nil + }, + }) + require.NoError(t, err) + + tk, err := cred.GetToken(context.Background(), testTRO) + require.NoError(t, err) + require.Equal(t, tokenValue, tk.Token) + require.Equal(t, int32(1), called.Load()) +} + func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { From cd5a6e330ea92b6f66bebbf585fa8db1d82b3669 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 25 Nov 2025 10:11:19 +1100 Subject: [PATCH 3/8] Add missing doc --- sdk/azidentity/workload_identity.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index 5d5a525d43c6..65ab4ea5ed8b 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -30,6 +30,7 @@ type WorkloadIdentityCredential struct { mtx *sync.RWMutex } +// WorkloadIdentityAzureProxyOptions contains optional parameters for configuring WorkloadIdentity Azure Proxy. type WorkloadIdentityAzureProxyOptions = customtokenproxy.Options // WorkloadIdentityCredentialOptions contains optional parameters for WorkloadIdentityCredential. From 5a14d1b0c8ff5c52271dabb951366a96950e68c3 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 25 Nov 2025 10:23:06 +1100 Subject: [PATCH 4/8] Update sdk/azidentity/workload_identity.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- sdk/azidentity/workload_identity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index 65ab4ea5ed8b..de70aeb3a44d 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -57,7 +57,7 @@ type WorkloadIdentityCredentialOptions struct { DisableInstanceDiscovery bool // EnableAzureProxy determines whether the credential reads proxy configuration from environment variables or - // from the AzureTokenProxyOptions field. + // from the AzureProxy field. // When this value is true and proxy configuration isn't present or this value is false, the credential will request // tokens directly from Entra ID. // From 05ed0cec0be25bf3b6f8b1a2a1ad5960a52517da Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 25 Nov 2025 10:23:19 +1100 Subject: [PATCH 5/8] Update sdk/azidentity/workload_identity.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- sdk/azidentity/workload_identity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index de70aeb3a44d..3ceb4a0f95af 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -99,7 +99,7 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( file := options.TokenFilePath if file == "" && options.GetFederatedToken == nil { if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok { - return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath or GetFederatedToken in the options") + return nil, errors.New("no token source specified. Check pod configuration or set GetFederatedToken or TokenFilePath in the options") } } tenantID := options.TenantID From 9f4d2738d54a0a8557a869047b7c33e287ad076b Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 25 Nov 2025 21:18:43 +1100 Subject: [PATCH 6/8] Fix typo --- sdk/azidentity/workload_identity_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index 44ea8267ddcd..3867330d4cc8 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -423,11 +423,11 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T } policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() - customTokenEndointServerCalledTimes := new(atomic.Int32) - customTokenEndointServer, caData := startTestTokenEndpointWithCAData( + customTokenEndpointServerCalledTimes := new(atomic.Int32) + customTokenEndpointServer, caData := startTestTokenEndpointWithCAData( t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - customTokenEndointServerCalledTimes.Add(1) + customTokenEndpointServerCalledTimes.Add(1) policyFlowCheck.Validate(t, req) @@ -456,7 +456,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T TenantID: fakeTenantID, TokenFilePath: tempFile, AzureProxy: WorkloadIdentityAzureProxyOptions{ - AzureKubernetesTokenProxy: customTokenEndointServer.URL, + AzureKubernetesTokenProxy: customTokenEndpointServer.URL, AzureKubernetesCAData: caData, }, }) @@ -465,7 +465,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T testGetTokenSuccess(t, cred) - require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) + require.Equal(t, int32(1), customTokenEndpointServerCalledTimes.Load()) } func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) { @@ -475,11 +475,11 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) } policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() - customTokenEndointServerCalledTimes := new(atomic.Int32) - customTokenEndointServer, caData := startTestTokenEndpointWithCAData( + customTokenEndpointServerCalledTimes := new(atomic.Int32) + customTokenEndpointServer, caData := startTestTokenEndpointWithCAData( t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - customTokenEndointServerCalledTimes.Add(1) + customTokenEndpointServerCalledTimes.Add(1) policyFlowCheck.Validate(t, req) @@ -496,7 +496,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndointServer.URL) + t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndpointServer.URL) t.Setenv(customtokenproxy.AzureKubernetesCAData, caData) clientOptions := policy.ClientOptions{ @@ -516,7 +516,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) testGetTokenSuccess(t, cred) - require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) + require.Equal(t, int32(1), customTokenEndpointServerCalledTimes.Load()) } func TestWorkloadIdentityCredential_CustomTokenEndpoint_InvalidSettings(t *testing.T) { From 06c9eb39d74dffd7ac0786456e8125e949f16127 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 2 Dec 2025 11:48:36 +1100 Subject: [PATCH 7/8] Refactor options --- .../default_azure_credential_test.go | 4 +- .../customtokenproxy/configuration.go | 80 ++++----- .../customtokenproxy/configuration_test.go | 169 ++++++++++-------- .../internal/exported/custom_token_proxy.go | 25 +++ sdk/azidentity/workload_identity.go | 34 ++-- sdk/azidentity/workload_identity_test.go | 85 +++++++-- 6 files changed, 252 insertions(+), 145 deletions(-) create mode 100644 sdk/azidentity/internal/exported/custom_token_proxy.go diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 8a98b8c154a4..082452004442 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -357,8 +357,8 @@ func TestDefaultAzureCredential_WorkloadIdentity(t *testing.T) { t.Setenv(azureTokenCredentials, credNameWorkloadIdentity) // these values should trigger validation errors if WorkloadIdentityCredential // tries to configure identity binding mode... - t.Setenv(customtokenproxy.AzureKubernetesCAData, "not a valid cert") - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, "http://timeout.local&fail=yes#please") + t.Setenv(customtokenproxy.EnvAzureKubernetesCAData, "not a valid cert") + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, "http://timeout.local&fail=yes#please") cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, diff --git a/sdk/azidentity/internal/customtokenproxy/configuration.go b/sdk/azidentity/internal/customtokenproxy/configuration.go index c638671b0458..0dcd9840c19a 100644 --- a/sdk/azidentity/internal/customtokenproxy/configuration.go +++ b/sdk/azidentity/internal/customtokenproxy/configuration.go @@ -10,35 +10,32 @@ import ( "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/exported" ) const ( - AzureKubernetesCAData = "AZURE_KUBERNETES_CA_DATA" - AzureKubernetesCAFile = "AZURE_KUBERNETES_CA_FILE" - AzureKubernetesSNIName = "AZURE_KUBERNETES_SNI_NAME" - - AzureKubernetesTokenProxy = "AZURE_KUBERNETES_TOKEN_PROXY" + EnvAzureKubernetesCAData = "AZURE_KUBERNETES_CA_DATA" + EnvAzureKubernetesCAFile = "AZURE_KUBERNETES_CA_FILE" + EnvAzureKubernetesSNIName = "AZURE_KUBERNETES_SNI_NAME" + EnvAzureKubernetesTokenProxy = "AZURE_KUBERNETES_TOKEN_PROXY" ) -// Options contains optional parameters for custom token proxy configuration. -type Options struct { - // AzureKubernetesCAData specifies the CA certificate data for the Kubernetes cluster. - // Corresponds to the AZURE_KUBERNETES_CA_DATA environment variable. - // At most one of AzureKubernetesCAData or AzureKubernetesCAFile should be set. - AzureKubernetesCAData string - - // AzureKubernetesCAFile specifies the path to the CA certificate file for the Kubernetes cluster. - // This field corresponds to the AZURE_KUBERNETES_CA_FILE environment variable. - // At most one of AzureKubernetesCAData or AzureKubernetesCAFile should be set. - AzureKubernetesCAFile string - - // AzureKubernetesSNIName specifies the name of the SNI for Kubernetes cluster. - // This field corresponds to the AZURE_KUBERNETES_SNI_NAME environment variable. - AzureKubernetesSNIName string - - // AzureKubernetesTokenProxy specifies the URL of the custom token proxy for the Kubernetes cluster. - // This field corresponds to the AZURE_KUBERNETES_TOKEN_PROXY environment variable. - AzureKubernetesTokenProxy string +func readOptionsFromEnv() *exported.CustomTokenProxyOptions { + return &exported.CustomTokenProxyOptions{ + TokenProxy: os.Getenv(EnvAzureKubernetesTokenProxy), + SNIName: os.Getenv(EnvAzureKubernetesSNIName), + CAFile: os.Getenv(EnvAzureKubernetesCAFile), + CAData: os.Getenv(EnvAzureKubernetesCAData), + } +} + +func backfillOptionsFromEnv(opts *exported.CustomTokenProxyOptions) { + if opts.CAData != "" || opts.CAFile != "" || opts.SNIName != "" || opts.TokenProxy != "" { + return + } + + // only backfill if all fields are empty + *opts = *readOptionsFromEnv() } func parseTokenProxyURL(endpoint string) (*url.URL, error) { @@ -65,21 +62,6 @@ func parseTokenProxyURL(endpoint string) (*url.URL, error) { return tokenProxy, nil } -func (o *Options) defaults() { - if o.AzureKubernetesTokenProxy == "" { - o.AzureKubernetesTokenProxy = os.Getenv(AzureKubernetesTokenProxy) - } - if o.AzureKubernetesSNIName == "" { - o.AzureKubernetesSNIName = os.Getenv(AzureKubernetesSNIName) - } - if o.AzureKubernetesCAFile == "" { - o.AzureKubernetesCAFile = os.Getenv(AzureKubernetesCAFile) - } - if o.AzureKubernetesCAData == "" { - o.AzureKubernetesCAData = os.Getenv(AzureKubernetesCAData) - } -} - var ( errCustomEndpointSetWithoutTokenProxy = errors.New( "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related settings are present", @@ -93,40 +75,40 @@ func noopConfigure(*policy.ClientOptions) { // no-op } -// Apply returns a function that configures the client options to use the custom token proxy. -func Apply(opts *Options) (func(*policy.ClientOptions), error) { +// GetClientOptionsConfigurer returns a function that configures the client options to use the custom token proxy. +func GetClientOptionsConfigurer(opts *exported.CustomTokenProxyOptions) (func(*policy.ClientOptions), error) { if opts == nil { return noopConfigure, nil } - opts.defaults() + backfillOptionsFromEnv(opts) - if opts.AzureKubernetesTokenProxy == "" { + if opts.TokenProxy == "" { // custom token proxy is not set, while other Kubernetes-related environment variables are present, // this is likely a configuration issue so erroring out to avoid misconfiguration - if opts.AzureKubernetesSNIName != "" || opts.AzureKubernetesCAFile != "" || opts.AzureKubernetesCAData != "" { + if opts.SNIName != "" || opts.CAFile != "" || opts.CAData != "" { return nil, errCustomEndpointSetWithoutTokenProxy } return noopConfigure, nil } - tokenProxy, err := parseTokenProxyURL(opts.AzureKubernetesTokenProxy) + tokenProxy, err := parseTokenProxyURL(opts.TokenProxy) if err != nil { return nil, err } // CAFile and CAData are mutually exclusive, at most one can be set. // If none of CAFile or CAData are set, the default system CA pool will be used. - if opts.AzureKubernetesCAFile != "" && opts.AzureKubernetesCAData != "" { + if opts.CAFile != "" && opts.CAData != "" { return nil, errCustomEndpointMultipleCASourcesSet } // preload the transport t := &transport{ - caFile: opts.AzureKubernetesCAFile, - caData: []byte(opts.AzureKubernetesCAData), - sniName: opts.AzureKubernetesSNIName, + caFile: opts.CAFile, + caData: []byte(opts.CAData), + sniName: opts.SNIName, tokenProxy: tokenProxy, } if _, err := t.getTokenTransporter(); err != nil { diff --git a/sdk/azidentity/internal/customtokenproxy/configuration_test.go b/sdk/azidentity/internal/customtokenproxy/configuration_test.go index ee44c9c15caf..d20a034dfad1 100644 --- a/sdk/azidentity/internal/customtokenproxy/configuration_test.go +++ b/sdk/azidentity/internal/customtokenproxy/configuration_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/exported" ) func TestParseTokenProxyURL(t *testing.T) { @@ -103,7 +104,7 @@ func TestOptions_Configure(t *testing.T) { tests := []struct { Name string Envs map[string]string - Options Options + Options exported.CustomTokenProxyOptions ClientOptions policy.ClientOptions ExpectErr bool AssertErr func(t testing.TB, err error) @@ -112,42 +113,42 @@ func TestOptions_Configure(t *testing.T) { { Name: "no custom endpoint", Envs: map[string]string{}, - Options: Options{}, + Options: exported.CustomTokenProxyOptions{}, ExpectErr: false, ExpectTransport: false, }, { Name: "custom endpoint enabled with minimal settings", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", }, ExpectErr: false, ExpectTransport: true, }, { Name: "custom endpoint enabled with CA file + SNI", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: testCAFile, - AzureKubernetesSNIName: "custom-sni.example.com", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAFile: testCAFile, + SNIName: "custom-sni.example.com", }, ExpectErr: false, ExpectTransport: true, }, { Name: "custom endpoint enabled with invalid CA file", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: "/non/existent/path/to/custom-ca-file.pem", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAFile: "/non/existent/path/to/custom-ca-file.pem", }, ExpectErr: true, ExpectTransport: false, }, { Name: "custom endpoint enabled with CA file contains invalid CA data", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: func() string { + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAFile: func() string { t.Helper() tempDir := t.TempDir() @@ -161,36 +162,36 @@ func TestOptions_Configure(t *testing.T) { }, { Name: "custom endpoint enabled with CA data + SNI", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: testCAData, - AzureKubernetesSNIName: "custom-sni.example.com", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: testCAData, + SNIName: "custom-sni.example.com", }, ExpectErr: false, ExpectTransport: true, }, { Name: "custom endpoint enabled with invalid CA data", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: string("invalid-ca-cert"), + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: string("invalid-ca-cert"), }, ExpectErr: true, ExpectTransport: false, }, { Name: "custom endpoint enabled with SNI", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesSNIName: "custom-sni.example.com", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + SNIName: "custom-sni.example.com", }, ExpectErr: false, ExpectTransport: true, }, { Name: "custom endpoint disabled with extra environment variables", - Options: Options{ - AzureKubernetesSNIName: "custom-sni.example.com", + Options: exported.CustomTokenProxyOptions{ + SNIName: "custom-sni.example.com", }, ExpectErr: true, AssertErr: func(t testing.TB, err error) { @@ -199,10 +200,10 @@ func TestOptions_Configure(t *testing.T) { }, { Name: "custom endpoint enabled with both CAData and CAFile", - Options: Options{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: testCAData, - AzureKubernetesCAFile: testCAFile, + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: testCAData, + CAFile: testCAFile, }, ExpectErr: true, AssertErr: func(t testing.TB, err error) { @@ -211,20 +212,20 @@ func TestOptions_Configure(t *testing.T) { }, { Name: "custom endpoint enabled with invalid endpoint", - Options: Options{ + Options: exported.CustomTokenProxyOptions{ // http endpoint is not allowed - AzureKubernetesTokenProxy: "http://custom-endpoint.com", + TokenProxy: "http://custom-endpoint.com", }, ExpectErr: true, }, { Name: "set by environment variables", Envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: testCAFile, - AzureKubernetesSNIName: "custom-sni.example.com", + EnvAzureKubernetesTokenProxy: "https://custom-endpoint.com", + EnvAzureKubernetesCAFile: testCAFile, + EnvAzureKubernetesSNIName: "custom-sni.example.com", }, - Options: Options{}, + Options: exported.CustomTokenProxyOptions{}, ExpectErr: false, ExpectTransport: true, }, @@ -238,7 +239,7 @@ func TestOptions_Configure(t *testing.T) { } } - mutateClientOptions, err := Apply(&tt.Options) + mutateClientOptions, err := GetClientOptionsConfigurer(&tt.Options) if tt.ExpectErr { require.Error(t, err) if tt.AssertErr != nil { @@ -260,41 +261,67 @@ func TestOptions_Configure(t *testing.T) { } } -func TestConfiguration_defaults(t *testing.T) { - t.Run("fills from env when empty", func(t *testing.T) { - expected := map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesSNIName: "sni.example.com", - AzureKubernetesCAFile: "/path/to/ca.pem", - AzureKubernetesCAData: "pem-data", - } - for k, v := range expected { - t.Setenv(k, v) - } - - opts := Options{} - opts.defaults() - - require.Equal(t, expected[AzureKubernetesTokenProxy], opts.AzureKubernetesTokenProxy) - require.Equal(t, expected[AzureKubernetesSNIName], opts.AzureKubernetesSNIName) - require.Equal(t, expected[AzureKubernetesCAFile], opts.AzureKubernetesCAFile) - require.Equal(t, expected[AzureKubernetesCAData], opts.AzureKubernetesCAData) - }) +func TestBackfillOptionsFromEnv(t *testing.T) { + tests := []struct { + Name string + Options exported.CustomTokenProxyOptions + Envs map[string]string + Expected exported.CustomTokenProxyOptions + }{ + { + Name: "empty", + Options: exported.CustomTokenProxyOptions{}, + Envs: map[string]string{}, + Expected: exported.CustomTokenProxyOptions{}, + }, + { + Name: "options field is not nil", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: "testCAData", + CAFile: "testCAFile", + SNIName: "custom-sni.example.com", + }, + Envs: map[string]string{ + EnvAzureKubernetesTokenProxy: "https://endpoint-from-env.com", + EnvAzureKubernetesCAData: "ca-data-from-env", + EnvAzureKubernetesCAFile: "ca-file-from-env", + EnvAzureKubernetesSNIName: "sni-name-from-env", + }, + Expected: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: "testCAData", + CAFile: "testCAFile", + SNIName: "custom-sni.example.com", + }, + }, + { + Name: "options field is nil", + Options: exported.CustomTokenProxyOptions{}, + Envs: map[string]string{ + EnvAzureKubernetesTokenProxy: "https://endpoint-from-env.com", + EnvAzureKubernetesCAData: "ca-data-from-env", + EnvAzureKubernetesCAFile: "ca-file-from-env", + EnvAzureKubernetesSNIName: "sni-name-from-env", + }, + Expected: exported.CustomTokenProxyOptions{ + TokenProxy: "https://endpoint-from-env.com", + CAData: "ca-data-from-env", + CAFile: "ca-file-from-env", + SNIName: "sni-name-from-env", + }, + }, + } - t.Run("preserves explicit values", func(t *testing.T) { - t.Setenv(AzureKubernetesTokenProxy, "https://env-value.com") - opts := Options{ - AzureKubernetesTokenProxy: "https://explicit.com", - AzureKubernetesSNIName: "explicit-sni", - AzureKubernetesCAFile: "/explicit/ca.pem", - AzureKubernetesCAData: "explicit-ca-data", - } + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { - opts.defaults() + for k, v := range tt.Envs { + t.Setenv(k, v) + } - require.Equal(t, "https://explicit.com", opts.AzureKubernetesTokenProxy) - require.Equal(t, "explicit-sni", opts.AzureKubernetesSNIName) - require.Equal(t, "/explicit/ca.pem", opts.AzureKubernetesCAFile) - require.Equal(t, "explicit-ca-data", opts.AzureKubernetesCAData) - }) + backfillOptionsFromEnv(&tt.Options) + require.Equal(t, tt.Expected, tt.Options) + }) + } } diff --git a/sdk/azidentity/internal/exported/custom_token_proxy.go b/sdk/azidentity/internal/exported/custom_token_proxy.go new file mode 100644 index 000000000000..78b0ef9057b7 --- /dev/null +++ b/sdk/azidentity/internal/exported/custom_token_proxy.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +// CustomTokenProxyOptions contains optional parameters for custom token proxy configuration. +type CustomTokenProxyOptions struct { + // CAData specifies the CA certificate data for the Kubernetes cluster. + // Corresponds to the AZURE_KUBERNETES_CA_DATA environment variable. + // At most one of CAData or CAFile should be set. + CAData string + + // CAFile specifies the path to the CA certificate file for the Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_CA_FILE environment variable. + // At most one of CAData or CAFile should be set. + CAFile string + + // SNIName specifies the name of the SNI for Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_SNI_NAME environment variable. + SNIName string + + // TokenProxy specifies the URL of the custom token proxy for the Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_TOKEN_PROXY environment variable. + TokenProxy string +} diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index 3ceb4a0f95af..c2825bc7b955 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -14,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/customtokenproxy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/exported" ) const credNameWorkloadIdentity = "WorkloadIdentityCredential" @@ -30,8 +31,8 @@ type WorkloadIdentityCredential struct { mtx *sync.RWMutex } -// WorkloadIdentityAzureProxyOptions contains optional parameters for configuring WorkloadIdentity Azure Proxy. -type WorkloadIdentityAzureProxyOptions = customtokenproxy.Options +// WorkloadIdentityCustomTokenProxyOptions contains optional parameters for configuring WorkloadIdentity Custom Token Proxy. +type WorkloadIdentityCustomTokenProxyOptions = exported.CustomTokenProxyOptions // WorkloadIdentityCredentialOptions contains optional parameters for WorkloadIdentityCredential. type WorkloadIdentityCredentialOptions struct { @@ -57,7 +58,7 @@ type WorkloadIdentityCredentialOptions struct { DisableInstanceDiscovery bool // EnableAzureProxy determines whether the credential reads proxy configuration from environment variables or - // from the AzureProxy field. + // from the CustomTokenProxy field. // When this value is true and proxy configuration isn't present or this value is false, the credential will request // tokens directly from Entra ID. // @@ -66,20 +67,21 @@ type WorkloadIdentityCredentialOptions struct { // to set this option: https://learn.microsoft.com/azure/aks/identity-bindings-concepts EnableAzureProxy bool - // AzureProxy specifies the options for the Azure proxy. - // If EnableAzureProxy is false, this field is ignored. - AzureProxy WorkloadIdentityAzureProxyOptions + // CustomTokenProxy specifies the options for the custom token proxy. + // It should not be set if EnableAzureProxy is true. + CustomTokenProxy *WorkloadIdentityCustomTokenProxyOptions // TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID. TenantID string - // TokenFilePath is the path of a file containing a Kubernetes service account token. Defaults to the value of the - // environment variable AZURE_FEDERATED_TOKEN_FILE. - // If GetFederatedToken is set, this field is ignored. + // TokenFilePath is the path of a file containing a Kubernetes service account token. + // This field is mutually exclusive with GetFederatedToken. + // If neither is specified, the credential will attempt to read the token path from the AZURE_FEDERATED_TOKEN_FILE environment variable. TokenFilePath string // GetFederatedToken defines an optional func to get the Kubernetes service account token. - // If this function is set, it will be used to get the token instead of reading it from TokenFilePath. + // This field is mutually exclusive with TokenFilePath. + // If neither is specified, the credential will attempt to read the token path from the AZURE_FEDERATED_TOKEN_FILE environment variable. GetFederatedToken func(ctx context.Context) (string, error) } @@ -97,6 +99,9 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } } file := options.TokenFilePath + if file != "" && options.GetFederatedToken != nil { + return nil, errors.New("TokenFilePath and GetFederatedToken cannot be set at the same time") + } if file == "" && options.GetFederatedToken == nil { if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok { return nil, errors.New("no token source specified. Check pod configuration or set GetFederatedToken or TokenFilePath in the options") @@ -108,6 +113,9 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( return nil, errors.New("no tenant ID specified. Check pod configuration or set TenantID in the options") } } + if !options.EnableAzureProxy && options.CustomTokenProxy != nil { + return nil, errors.New("CustomTokenProxy should not be set if EnableAzureProxy is false") + } w := WorkloadIdentityCredential{ file: file, @@ -122,7 +130,11 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } if options.EnableAzureProxy { - mutateClientOptions, err := customtokenproxy.Apply(&options.AzureProxy) + customTokenProxyOptions := options.CustomTokenProxy + if customTokenProxyOptions == nil { + customTokenProxyOptions = &WorkloadIdentityCustomTokenProxyOptions{} + } + mutateClientOptions, err := customtokenproxy.GetClientOptionsConfigurer(customTokenProxyOptions) if err != nil { return nil, err } diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index 3867330d4cc8..5628061d0884 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -455,9 +455,9 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T EnableAzureProxy: true, TenantID: fakeTenantID, TokenFilePath: tempFile, - AzureProxy: WorkloadIdentityAzureProxyOptions{ - AzureKubernetesTokenProxy: customTokenEndpointServer.URL, - AzureKubernetesCAData: caData, + CustomTokenProxy: &WorkloadIdentityCustomTokenProxyOptions{ + TokenProxy: customTokenEndpointServer.URL, + CAData: caData, }, }) require.NoError(t, err) @@ -496,8 +496,8 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndpointServer.URL) - t.Setenv(customtokenproxy.AzureKubernetesCAData, caData) + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, customTokenEndpointServer.URL) + t.Setenv(customtokenproxy.EnvAzureKubernetesCAData, caData) clientOptions := policy.ClientOptions{ PerCallPolicies: []policy.Policy{ @@ -520,7 +520,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) } func TestWorkloadIdentityCredential_CustomTokenEndpoint_InvalidSettings(t *testing.T) { - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, "invalid-token-endpoint") + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, "invalid-token-endpoint") _, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientID: fakeClientID, EnableAzureProxy: true, @@ -558,11 +558,11 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAFile(t *testing.T) }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndointServer.URL) + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, customTokenEndointServer.URL) d := t.TempDir() caFile := filepath.Join(d, "test-ca-file") require.NoError(t, os.WriteFile(caFile, []byte(caData), 0600)) - t.Setenv(customtokenproxy.AzureKubernetesCAFile, caFile) + t.Setenv(customtokenproxy.EnvAzureKubernetesCAFile, caFile) clientOptions := policy.ClientOptions{ PerCallPolicies: []policy.Policy{ @@ -584,7 +584,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAFile(t *testing.T) require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) } -func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup(t *testing.T) { +func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup_FromEnv(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { t.Fatalf("failed to write token file: %v", err) @@ -616,13 +616,13 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup(t *testing.T) { }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndointServer.URL) - t.Setenv(customtokenproxy.AzureKubernetesSNIName, sniName) + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, customTokenEndointServer.URL) + t.Setenv(customtokenproxy.EnvAzureKubernetesSNIName, sniName) d := t.TempDir() caFile := filepath.Join(d, "test-ca-file") require.NoError(t, os.WriteFile(caFile, []byte(caData), 0600)) - t.Setenv(customtokenproxy.AzureKubernetesCAFile, caFile) + t.Setenv(customtokenproxy.EnvAzureKubernetesCAFile, caFile) clientOptions := policy.ClientOptions{ PerCallPolicies: []policy.Policy{ @@ -643,3 +643,64 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup(t *testing.T) { require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) } + +func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup_FromOptions(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") + if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() + sniName := "test-sni.example.com" + + customTokenEndointServerCalledTimes := new(atomic.Int32) + customTokenEndointServer, caData := startTestTokenEndpointWithCAData( + t, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + customTokenEndointServerCalledTimes.Add(1) + + policyFlowCheck.Validate(t, req) + + require.NotNil(t, req.TLS) + require.Equal(t, req.TLS.ServerName, sniName, "when SNI is set, request should set SNI") + + require.NoError(t, req.ParseForm()) + require.NotEmpty(t, req.PostForm) + + require.Contains(t, req.PostForm, "client_assertion") + require.Equal(t, req.PostForm.Get("client_assertion"), testClientAssertion) + + require.Contains(t, req.PostForm, "client_id") + require.Equal(t, req.PostForm.Get("client_id"), fakeClientID) + + _, _ = w.Write(accessTokenRespSuccess) + }), + ) + + d := t.TempDir() + caFile := filepath.Join(d, "test-ca-file") + require.NoError(t, os.WriteFile(caFile, []byte(caData), 0600)) + + clientOptions := policy.ClientOptions{ + PerCallPolicies: []policy.Policy{ + policyFlowCheck.Policy(), + }, + } + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + ClientOptions: clientOptions, + EnableAzureProxy: true, + TenantID: fakeTenantID, + TokenFilePath: tempFile, + CustomTokenProxy: &WorkloadIdentityCustomTokenProxyOptions{ + TokenProxy: customTokenEndointServer.URL, + CAFile: caFile, + SNIName: sniName, + }, + }) + require.NoError(t, err) + require.Nil(t, clientOptions.Transport, "constructor shouldn't mutate caller's ClientOptions") + + testGetTokenSuccess(t, cred) + + require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) +} From b3ac336d3bd88bef69e6a29d1bd022f72846fc32 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Tue, 2 Dec 2025 11:51:11 +1100 Subject: [PATCH 8/8] Remove go.work.sum --- sdk/azidentity/go.work.sum | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 sdk/azidentity/go.work.sum diff --git a/sdk/azidentity/go.work.sum b/sdk/azidentity/go.work.sum deleted file mode 100644 index 7a078fbab348..000000000000 --- a/sdk/azidentity/go.work.sum +++ /dev/null @@ -1,8 +0,0 @@ -github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/keybase/dbus v0.0.0-20220506165403-5aa21ea2c23a/go.mod h1:YPNKjjE7Ubp9dTbnWvsP3HT+hYnY6TfXzubYTBeUxc8= -github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=