diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 8a98b8c154a4..082452004442 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -357,8 +357,8 @@ func TestDefaultAzureCredential_WorkloadIdentity(t *testing.T) { t.Setenv(azureTokenCredentials, credNameWorkloadIdentity) // these values should trigger validation errors if WorkloadIdentityCredential // tries to configure identity binding mode... - t.Setenv(customtokenproxy.AzureKubernetesCAData, "not a valid cert") - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, "http://timeout.local&fail=yes#please") + t.Setenv(customtokenproxy.EnvAzureKubernetesCAData, "not a valid cert") + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, "http://timeout.local&fail=yes#please") cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, diff --git a/sdk/azidentity/internal/customtokenproxy/configuration.go b/sdk/azidentity/internal/customtokenproxy/configuration.go new file mode 100644 index 000000000000..0dcd9840c19a --- /dev/null +++ b/sdk/azidentity/internal/customtokenproxy/configuration.go @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package customtokenproxy + +import ( + "errors" + "fmt" + "net/url" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/exported" +) + +const ( + EnvAzureKubernetesCAData = "AZURE_KUBERNETES_CA_DATA" + EnvAzureKubernetesCAFile = "AZURE_KUBERNETES_CA_FILE" + EnvAzureKubernetesSNIName = "AZURE_KUBERNETES_SNI_NAME" + EnvAzureKubernetesTokenProxy = "AZURE_KUBERNETES_TOKEN_PROXY" +) + +func readOptionsFromEnv() *exported.CustomTokenProxyOptions { + return &exported.CustomTokenProxyOptions{ + TokenProxy: os.Getenv(EnvAzureKubernetesTokenProxy), + SNIName: os.Getenv(EnvAzureKubernetesSNIName), + CAFile: os.Getenv(EnvAzureKubernetesCAFile), + CAData: os.Getenv(EnvAzureKubernetesCAData), + } +} + +func backfillOptionsFromEnv(opts *exported.CustomTokenProxyOptions) { + if opts.CAData != "" || opts.CAFile != "" || opts.SNIName != "" || opts.TokenProxy != "" { + return + } + + // only backfill if all fields are empty + *opts = *readOptionsFromEnv() +} + +func parseTokenProxyURL(endpoint string) (*url.URL, error) { + tokenProxy, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("failed to parse custom token proxy URL %q: %s", endpoint, err) + } + if tokenProxy.Scheme != "https" { + return nil, fmt.Errorf("custom token endpoint must use https scheme, got %q", tokenProxy.Scheme) + } + if tokenProxy.User != nil { + return nil, fmt.Errorf("custom token endpoint URL %q must not contain user info", tokenProxy) + } + if tokenProxy.RawQuery != "" { + return nil, fmt.Errorf("custom token endpoint URL %q must not contain a query", tokenProxy) + } + if tokenProxy.EscapedFragment() != "" { + return nil, fmt.Errorf("custom token endpoint URL %q must not contain a fragment", tokenProxy) + } + if tokenProxy.EscapedPath() == "" { + // if the path is empty, set it to "/" to avoid stripping the path from req.URL + tokenProxy.Path = "/" + } + return tokenProxy, nil +} + +var ( + errCustomEndpointSetWithoutTokenProxy = errors.New( + "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related settings are present", + ) + errCustomEndpointMultipleCASourcesSet = errors.New( + "only one of AzureKubernetesCAFile or AzureKubernetesCAData can be specified", + ) +) + +func noopConfigure(*policy.ClientOptions) { + // no-op +} + +// GetClientOptionsConfigurer returns a function that configures the client options to use the custom token proxy. +func GetClientOptionsConfigurer(opts *exported.CustomTokenProxyOptions) (func(*policy.ClientOptions), error) { + if opts == nil { + return noopConfigure, nil + } + + backfillOptionsFromEnv(opts) + + if opts.TokenProxy == "" { + // custom token proxy is not set, while other Kubernetes-related environment variables are present, + // this is likely a configuration issue so erroring out to avoid misconfiguration + if opts.SNIName != "" || opts.CAFile != "" || opts.CAData != "" { + return nil, errCustomEndpointSetWithoutTokenProxy + } + + return noopConfigure, nil + } + + tokenProxy, err := parseTokenProxyURL(opts.TokenProxy) + if err != nil { + return nil, err + } + + // CAFile and CAData are mutually exclusive, at most one can be set. + // If none of CAFile or CAData are set, the default system CA pool will be used. + if opts.CAFile != "" && opts.CAData != "" { + return nil, errCustomEndpointMultipleCASourcesSet + } + + // preload the transport + t := &transport{ + caFile: opts.CAFile, + caData: []byte(opts.CAData), + sniName: opts.SNIName, + tokenProxy: tokenProxy, + } + if _, err := t.getTokenTransporter(); err != nil { + return nil, err + } + + return func(clientOptions *policy.ClientOptions) { + clientOptions.Transport = t + }, nil +} diff --git a/sdk/azidentity/internal/customtokenproxy/configuration_test.go b/sdk/azidentity/internal/customtokenproxy/configuration_test.go new file mode 100644 index 000000000000..d20a034dfad1 --- /dev/null +++ b/sdk/azidentity/internal/customtokenproxy/configuration_test.go @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package customtokenproxy + +import ( + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/exported" +) + +func TestParseTokenProxyURL(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + endpoint string + check func(t testing.TB, u *url.URL, err error) + }{ + { + name: "valid https endpoint without path", + endpoint: "https://example.com", + check: func(t testing.TB, u *url.URL, err error) { + require.NoError(t, err) + require.Equal(t, "https", u.Scheme) + require.Equal(t, "example.com", u.Host) + require.Equal(t, "", u.RawQuery) + require.Equal(t, "", u.Fragment) + require.Equal(t, "/", u.Path, "should set path to '/' if not present") + }, + }, + { + name: "valid https endpoint with path", + endpoint: "https://example.com/token/path", + check: func(t testing.TB, u *url.URL, err error) { + require.NoError(t, err) + require.Equal(t, "/token/path", u.Path) + }, + }, + { + name: "reject non-https scheme", + endpoint: "http://example.com", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "https scheme") + }, + }, + { + name: "reject user info", + endpoint: "https://user:pass@example.com/token", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "must not contain user info") + }, + }, + { + name: "reject query params", + endpoint: "https://example.com/token?foo=bar", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "must not contain a query") + }, + }, + { + name: "reject fragment", + endpoint: "https://example.com/token#frag", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "must not contain a fragment") + }, + }, + { + name: "reject unparseable URL", + endpoint: "https://example.com/%zz", + check: func(t testing.TB, _ *url.URL, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "failed to parse custom token proxy URL") + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + u, err := parseTokenProxyURL(c.endpoint) + c.check(t, u, err) + }) + } +} + +func TestOptions_Configure(t *testing.T) { + var ( + testCAData = string(createTestCA(t)) + testCAFile = createTestCAFile(t) + ) + + tests := []struct { + Name string + Envs map[string]string + Options exported.CustomTokenProxyOptions + ClientOptions policy.ClientOptions + ExpectErr bool + AssertErr func(t testing.TB, err error) + ExpectTransport bool + }{ + { + Name: "no custom endpoint", + Envs: map[string]string{}, + Options: exported.CustomTokenProxyOptions{}, + ExpectErr: false, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with minimal settings", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint enabled with CA file + SNI", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAFile: testCAFile, + SNIName: "custom-sni.example.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint enabled with invalid CA file", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAFile: "/non/existent/path/to/custom-ca-file.pem", + }, + ExpectErr: true, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with CA file contains invalid CA data", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAFile: func() string { + t.Helper() + + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "invalid-ca-file.pem") + require.NoError(t, os.WriteFile(caFile, []byte("invalid-ca-cert"), 0600)) + return caFile + }(), + }, + ExpectErr: true, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with CA data + SNI", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: testCAData, + SNIName: "custom-sni.example.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint enabled with invalid CA data", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: string("invalid-ca-cert"), + }, + ExpectErr: true, + ExpectTransport: false, + }, + { + Name: "custom endpoint enabled with SNI", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + SNIName: "custom-sni.example.com", + }, + ExpectErr: false, + ExpectTransport: true, + }, + { + Name: "custom endpoint disabled with extra environment variables", + Options: exported.CustomTokenProxyOptions{ + SNIName: "custom-sni.example.com", + }, + ExpectErr: true, + AssertErr: func(t testing.TB, err error) { + require.ErrorIs(t, err, errCustomEndpointSetWithoutTokenProxy) + }, + }, + { + Name: "custom endpoint enabled with both CAData and CAFile", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: testCAData, + CAFile: testCAFile, + }, + ExpectErr: true, + AssertErr: func(t testing.TB, err error) { + require.ErrorIs(t, err, errCustomEndpointMultipleCASourcesSet) + }, + }, + { + Name: "custom endpoint enabled with invalid endpoint", + Options: exported.CustomTokenProxyOptions{ + // http endpoint is not allowed + TokenProxy: "http://custom-endpoint.com", + }, + ExpectErr: true, + }, + { + Name: "set by environment variables", + Envs: map[string]string{ + EnvAzureKubernetesTokenProxy: "https://custom-endpoint.com", + EnvAzureKubernetesCAFile: testCAFile, + EnvAzureKubernetesSNIName: "custom-sni.example.com", + }, + Options: exported.CustomTokenProxyOptions{}, + ExpectErr: false, + ExpectTransport: true, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + if len(tt.Envs) > 0 { + for k, v := range tt.Envs { + t.Setenv(k, v) + } + } + + mutateClientOptions, err := GetClientOptionsConfigurer(&tt.Options) + if tt.ExpectErr { + require.Error(t, err) + if tt.AssertErr != nil { + tt.AssertErr(t, err) + } + return + } + + require.NoError(t, err) + + mutateClientOptions(&tt.ClientOptions) + if tt.ExpectTransport { + require.NotNil(t, tt.ClientOptions.Transport) + require.IsType(t, &transport{}, tt.ClientOptions.Transport) + } else { + require.Nil(t, tt.ClientOptions.Transport) + } + }) + } +} + +func TestBackfillOptionsFromEnv(t *testing.T) { + tests := []struct { + Name string + Options exported.CustomTokenProxyOptions + Envs map[string]string + Expected exported.CustomTokenProxyOptions + }{ + { + Name: "empty", + Options: exported.CustomTokenProxyOptions{}, + Envs: map[string]string{}, + Expected: exported.CustomTokenProxyOptions{}, + }, + { + Name: "options field is not nil", + Options: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: "testCAData", + CAFile: "testCAFile", + SNIName: "custom-sni.example.com", + }, + Envs: map[string]string{ + EnvAzureKubernetesTokenProxy: "https://endpoint-from-env.com", + EnvAzureKubernetesCAData: "ca-data-from-env", + EnvAzureKubernetesCAFile: "ca-file-from-env", + EnvAzureKubernetesSNIName: "sni-name-from-env", + }, + Expected: exported.CustomTokenProxyOptions{ + TokenProxy: "https://custom-endpoint.com", + CAData: "testCAData", + CAFile: "testCAFile", + SNIName: "custom-sni.example.com", + }, + }, + { + Name: "options field is nil", + Options: exported.CustomTokenProxyOptions{}, + Envs: map[string]string{ + EnvAzureKubernetesTokenProxy: "https://endpoint-from-env.com", + EnvAzureKubernetesCAData: "ca-data-from-env", + EnvAzureKubernetesCAFile: "ca-file-from-env", + EnvAzureKubernetesSNIName: "sni-name-from-env", + }, + Expected: exported.CustomTokenProxyOptions{ + TokenProxy: "https://endpoint-from-env.com", + CAData: "ca-data-from-env", + CAFile: "ca-file-from-env", + SNIName: "sni-name-from-env", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + + for k, v := range tt.Envs { + t.Setenv(k, v) + } + + backfillOptionsFromEnv(&tt.Options) + require.Equal(t, tt.Expected, tt.Options) + }) + } +} diff --git a/sdk/azidentity/internal/customtokenproxy/transport.go b/sdk/azidentity/internal/customtokenproxy/transport.go index 6c0fc6244203..db4fe93905d0 100644 --- a/sdk/azidentity/internal/customtokenproxy/transport.go +++ b/sdk/azidentity/internal/customtokenproxy/transport.go @@ -13,49 +13,6 @@ import ( "net/url" "os" "time" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" -) - -const ( - AzureKubernetesCAData = "AZURE_KUBERNETES_CA_DATA" - AzureKubernetesCAFile = "AZURE_KUBERNETES_CA_FILE" - AzureKubernetesSNIName = "AZURE_KUBERNETES_SNI_NAME" - - AzureKubernetesTokenProxy = "AZURE_KUBERNETES_TOKEN_PROXY" -) - -func parseAndValidate(endpoint string) (*url.URL, error) { - tokenProxy, err := url.Parse(endpoint) - if err != nil { - return nil, fmt.Errorf("failed to parse custom token proxy URL %q: %s", endpoint, err) - } - if tokenProxy.Scheme != "https" { - return nil, fmt.Errorf("custom token endpoint must use https scheme, got %q", tokenProxy.Scheme) - } - if tokenProxy.User != nil { - return nil, fmt.Errorf("custom token endpoint URL %q must not contain user info", tokenProxy) - } - if tokenProxy.RawQuery != "" { - return nil, fmt.Errorf("custom token endpoint URL %q must not contain a query", tokenProxy) - } - if tokenProxy.EscapedFragment() != "" { - return nil, fmt.Errorf("custom token endpoint URL %q must not contain a fragment", tokenProxy) - } - if tokenProxy.EscapedPath() == "" { - // if the path is empty, set it to "/" to avoid stripping the path from req.URL - tokenProxy.Path = "/" - } - return tokenProxy, nil -} - -var ( - errCustomEndpointEnvSetWithoutTokenProxy = errors.New( - "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present", - ) - errCustomEndpointMultipleCASourcesSet = errors.New( - "only one of AZURE_KUBERNETES_CA_FILE and AZURE_KUBERNETES_CA_DATA can be specified", - ) ) func createTransport(sniName string, caPool *x509.CertPool) *http.Transport { @@ -82,49 +39,6 @@ func createTransport(sniName string, caPool *x509.CertPool) *http.Transport { return transport } -// Configure configures custom token endpoint mode if the required environment variables are present. -func Configure(clientOptions *policy.ClientOptions) error { - kubernetesTokenProxyStr := os.Getenv(AzureKubernetesTokenProxy) - - kubernetesSNIName := os.Getenv(AzureKubernetesSNIName) - kubernetesCAFile := os.Getenv(AzureKubernetesCAFile) - kubernetesCAData := os.Getenv(AzureKubernetesCAData) - - if kubernetesTokenProxyStr == "" { - // custom token proxy is not set, while other Kubernetes-related environment variables are present, - // this is likely a configuration issue so erroring out to avoid misconfiguration - if kubernetesSNIName != "" || kubernetesCAFile != "" || kubernetesCAData != "" { - return errCustomEndpointEnvSetWithoutTokenProxy - } - - return nil - } - tokenProxy, err := parseAndValidate(kubernetesTokenProxyStr) - if err != nil { - return err - } - - // CAFile and CAData are mutually exclusive, at most one can be set. - // If none of CAFile or CAData are set, the default system CA pool will be used. - if kubernetesCAFile != "" && kubernetesCAData != "" { - return errCustomEndpointMultipleCASourcesSet - } - - // preload the transport - t := &transport{ - caFile: kubernetesCAFile, - caData: []byte(kubernetesCAData), - sniName: kubernetesSNIName, - tokenProxy: tokenProxy, - } - if _, err := t.getTokenTransporter(); err != nil { - return err - } - - clientOptions.Transport = t - return nil -} - // transport redirects requests to the configured proxy. // // Lock is not needed for internal caData as this transport is called under confidentialClient's lock. diff --git a/sdk/azidentity/internal/customtokenproxy/transport_test.go b/sdk/azidentity/internal/customtokenproxy/transport_test.go index c74a11a9169b..97834a4711de 100644 --- a/sdk/azidentity/internal/customtokenproxy/transport_test.go +++ b/sdk/azidentity/internal/customtokenproxy/transport_test.go @@ -19,237 +19,9 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/stretchr/testify/require" ) -func TestParseAndValidate(t *testing.T) { - cases := []struct { - name string - endpoint string - check func(t testing.TB, u *url.URL, err error) - }{ - { - name: "valid https endpoint without path", - endpoint: "https://example.com", - check: func(t testing.TB, u *url.URL, err error) { - require.NoError(t, err) - require.Equal(t, "https", u.Scheme) - require.Equal(t, "example.com", u.Host) - require.Equal(t, "", u.RawQuery) - require.Equal(t, "", u.Fragment) - require.Equal(t, "/", u.Path, "should set path to '/' if not present") - }, - }, - { - name: "valid https endpoint with path", - endpoint: "https://example.com/token/path", - check: func(t testing.TB, u *url.URL, err error) { - require.NoError(t, err) - require.Equal(t, "/token/path", u.Path) - }, - }, - { - name: "reject non-https scheme", - endpoint: "http://example.com", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "https scheme") - }, - }, - { - name: "reject user info", - endpoint: "https://user:pass@example.com/token", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "must not contain user info") - }, - }, - { - name: "reject query params", - endpoint: "https://example.com/token?foo=bar", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "must not contain a query") - }, - }, - { - name: "reject fragment", - endpoint: "https://example.com/token#frag", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "must not contain a fragment") - }, - }, - { - name: "reject unparseable URL", - endpoint: "https://example.com/%zz", - check: func(t testing.TB, _ *url.URL, err error) { - require.Error(t, err) - require.ErrorContains(t, err, "failed to parse custom token proxy URL") - }, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - u, err := parseAndValidate(c.endpoint) - c.check(t, u, err) - }) - } -} - -func TestConfigure(t *testing.T) { - var ( - testCAData = string(createTestCA(t)) - testCAFile = createTestCAFile(t) - ) - - cases := []struct { - name string - envs map[string]string - clientOptions policy.ClientOptions - - expectErr bool - checkErr func(t testing.TB, err error) // optional check on error - expectTransport bool - }{ - { - name: "no custom endpoint", - expectErr: false, - expectTransport: false, - }, - { - name: "custom endpoint enabled with minimal settings", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint enabled with CA file + SNI", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: testCAFile, - AzureKubernetesSNIName: "custom-sni.example.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint enabled with invalid CA file", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: "/non/existent/path/to/custom-ca-file.pem", - }, - expectTransport: false, - }, - { - name: "custom endpoint enabled with CA file contains invalid CA data", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAFile: func() string { - t.Helper() - - tempDir := t.TempDir() - caFile := filepath.Join(tempDir, "invalid-ca-file.pem") - require.NoError(t, os.WriteFile(caFile, []byte("invalid-ca-cert"), 0600)) - return caFile - }(), - }, - expectTransport: false, - }, - { - name: "custom endpoint enabled with CA data + SNI", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: testCAData, - AzureKubernetesSNIName: "custom-sni.example.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint enabled with invalid CA data", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: string("invalid-ca-cert"), - }, - expectTransport: false, - }, - { - name: "custom endpoint enabled with SNI", - expectErr: false, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesSNIName: "custom-sni.example.com", - }, - expectTransport: true, - }, - { - name: "custom endpoint disabled with extra environment variables", - expectErr: true, - envs: map[string]string{ - AzureKubernetesSNIName: "custom-sni.example.com", - }, - checkErr: func(t testing.TB, err error) { - require.ErrorIs(t, err, errCustomEndpointEnvSetWithoutTokenProxy) - }, - }, - { - name: "custom endpoint enabled with both CAData and CAFile", - expectErr: true, - envs: map[string]string{ - AzureKubernetesTokenProxy: "https://custom-endpoint.com", - AzureKubernetesCAData: testCAData, - AzureKubernetesCAFile: testCAFile, - }, - checkErr: func(t testing.TB, err error) { - require.ErrorIs(t, err, errCustomEndpointMultipleCASourcesSet) - }, - }, - { - name: "custom endpoint enabled with invalid endpoint", - expectErr: true, - envs: map[string]string{ - // http endpoint is not allowed - AzureKubernetesTokenProxy: "http://custom-endpoint.com", - }, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - if len(c.envs) > 0 { - for k, v := range c.envs { - t.Setenv(k, v) - } - } - err := Configure(&c.clientOptions) - - if c.expectErr { - require.Error(t, err) - if c.checkErr != nil { - c.checkErr(t, err) - } - return - } - - require.NoError(t, err) - if c.expectTransport { - require.NotNil(t, c.clientOptions.Transport) - require.IsType(t, &transport{}, c.clientOptions.Transport) - } else { - require.Nil(t, c.clientOptions.Transport) - } - }) - } -} - // createTestCA creates a valid CA as bytes func createTestCA(t testing.TB) []byte { t.Helper() @@ -439,7 +211,6 @@ func TestGetTokenTransporter_reentry(t *testing.T) { func TestTransport_Do(t *testing.T) { mux := http.NewServeMux() testServer := httptest.NewTLSServer(mux) - ca := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: testServer.Certificate().Raw}) require.NotEmpty(t, ca) diff --git a/sdk/azidentity/internal/exported/custom_token_proxy.go b/sdk/azidentity/internal/exported/custom_token_proxy.go new file mode 100644 index 000000000000..78b0ef9057b7 --- /dev/null +++ b/sdk/azidentity/internal/exported/custom_token_proxy.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +// CustomTokenProxyOptions contains optional parameters for custom token proxy configuration. +type CustomTokenProxyOptions struct { + // CAData specifies the CA certificate data for the Kubernetes cluster. + // Corresponds to the AZURE_KUBERNETES_CA_DATA environment variable. + // At most one of CAData or CAFile should be set. + CAData string + + // CAFile specifies the path to the CA certificate file for the Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_CA_FILE environment variable. + // At most one of CAData or CAFile should be set. + CAFile string + + // SNIName specifies the name of the SNI for Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_SNI_NAME environment variable. + SNIName string + + // TokenProxy specifies the URL of the custom token proxy for the Kubernetes cluster. + // This field corresponds to the AZURE_KUBERNETES_TOKEN_PROXY environment variable. + TokenProxy string +} diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index 11b43e0904f3..c2825bc7b955 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -14,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/customtokenproxy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal/exported" ) const credNameWorkloadIdentity = "WorkloadIdentityCredential" @@ -23,12 +24,16 @@ const credNameWorkloadIdentity = "WorkloadIdentityCredential" // // [Azure Kubernetes Service documentation]: https://learn.microsoft.com/azure/aks/workload-identity-overview type WorkloadIdentityCredential struct { - assertion, file string - cred *ClientAssertionCredential - expires time.Time - mtx *sync.RWMutex + assertion, file string + getAssertionOverride func(ctx context.Context) (string, error) + cred *ClientAssertionCredential + expires time.Time + mtx *sync.RWMutex } +// WorkloadIdentityCustomTokenProxyOptions contains optional parameters for configuring WorkloadIdentity Custom Token Proxy. +type WorkloadIdentityCustomTokenProxyOptions = exported.CustomTokenProxyOptions + // WorkloadIdentityCredentialOptions contains optional parameters for WorkloadIdentityCredential. type WorkloadIdentityCredentialOptions struct { azcore.ClientOptions @@ -52,8 +57,9 @@ type WorkloadIdentityCredentialOptions struct { // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool - // EnableAzureProxy determines whether the credential reads proxy configuration from environment variables. When - // this value is true and proxy configuration isn't present or this value is false, the credential will request + // EnableAzureProxy determines whether the credential reads proxy configuration from environment variables or + // from the CustomTokenProxy field. + // When this value is true and proxy configuration isn't present or this value is false, the credential will request // tokens directly from Entra ID. // // The proxy feature is designed for applications that deploy to many clusters and clusters that host many @@ -61,12 +67,22 @@ type WorkloadIdentityCredentialOptions struct { // to set this option: https://learn.microsoft.com/azure/aks/identity-bindings-concepts EnableAzureProxy bool + // CustomTokenProxy specifies the options for the custom token proxy. + // It should not be set if EnableAzureProxy is true. + CustomTokenProxy *WorkloadIdentityCustomTokenProxyOptions + // TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID. TenantID string - // TokenFilePath is the path of a file containing a Kubernetes service account token. Defaults to the value of the - // environment variable AZURE_FEDERATED_TOKEN_FILE. + // TokenFilePath is the path of a file containing a Kubernetes service account token. + // This field is mutually exclusive with GetFederatedToken. + // If neither is specified, the credential will attempt to read the token path from the AZURE_FEDERATED_TOKEN_FILE environment variable. TokenFilePath string + + // GetFederatedToken defines an optional func to get the Kubernetes service account token. + // This field is mutually exclusive with TokenFilePath. + // If neither is specified, the credential will attempt to read the token path from the AZURE_FEDERATED_TOKEN_FILE environment variable. + GetFederatedToken func(ctx context.Context) (string, error) } // NewWorkloadIdentityCredential constructs a WorkloadIdentityCredential. Service principal configuration is read @@ -83,9 +99,12 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } } file := options.TokenFilePath - if file == "" { + if file != "" && options.GetFederatedToken != nil { + return nil, errors.New("TokenFilePath and GetFederatedToken cannot be set at the same time") + } + if file == "" && options.GetFederatedToken == nil { if file, ok = os.LookupEnv(azureFederatedTokenFile); !ok { - return nil, errors.New("no token file specified. Check pod configuration or set TokenFilePath in the options") + return nil, errors.New("no token source specified. Check pod configuration or set GetFederatedToken or TokenFilePath in the options") } } tenantID := options.TenantID @@ -94,8 +113,15 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( return nil, errors.New("no tenant ID specified. Check pod configuration or set TenantID in the options") } } + if !options.EnableAzureProxy && options.CustomTokenProxy != nil { + return nil, errors.New("CustomTokenProxy should not be set if EnableAzureProxy is false") + } - w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}} + w := WorkloadIdentityCredential{ + file: file, + getAssertionOverride: options.GetFederatedToken, + mtx: &sync.RWMutex{}, + } caco := &ClientAssertionCredentialOptions{ AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, Cache: options.Cache, @@ -104,9 +130,15 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( } if options.EnableAzureProxy { - if err := customtokenproxy.Configure(&caco.ClientOptions); err != nil { + customTokenProxyOptions := options.CustomTokenProxy + if customTokenProxyOptions == nil { + customTokenProxyOptions = &WorkloadIdentityCustomTokenProxyOptions{} + } + mutateClientOptions, err := customtokenproxy.GetClientOptionsConfigurer(customTokenProxyOptions) + if err != nil { return nil, err } + mutateClientOptions(&caco.ClientOptions) } cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, caco) @@ -130,7 +162,11 @@ func (w *WorkloadIdentityCredential) GetToken(ctx context.Context, opts policy.T // getAssertion returns the specified file's content, which is expected to be a Kubernetes service account token. // Kubernetes is responsible for updating the file as service account tokens expire. -func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, error) { +func (w *WorkloadIdentityCredential) getAssertion(ctx context.Context) (string, error) { + if w.getAssertionOverride != nil { + return w.getAssertionOverride(ctx) + } + w.mtx.RLock() if w.expires.Before(time.Now()) { // ensure only one goroutine at a time updates the assertion diff --git a/sdk/azidentity/workload_identity_test.go b/sdk/azidentity/workload_identity_test.go index 85fc680e4be0..5628061d0884 100644 --- a/sdk/azidentity/workload_identity_test.go +++ b/sdk/azidentity/workload_identity_test.go @@ -388,6 +388,86 @@ func (c *customTokenRequestPolicyFlowCheck) Validate(t testing.TB, req *http.Req require.Equal(t, c.requiredHeaderValue, req.Header.Get(c.requiredHeaderKey)) } +func TestWorkloadIdentityCredential_GetFederatedTokenOverride(t *testing.T) { + const overrideToken = "override-token" + called := new(atomic.Int32) + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + TenantID: fakeTenantID, + ClientOptions: policy.ClientOptions{ + Transport: &mockSTS{ + tokenRequestCallback: func(req *http.Request) *http.Response { + require.NoError(t, req.ParseForm()) + require.Equal(t, overrideToken, req.PostForm.Get("client_assertion")) + called.Add(1) + return nil + }, + }, + }, + GetFederatedToken: func(ctx context.Context) (string, error) { + return overrideToken, nil + }, + }) + require.NoError(t, err) + + tk, err := cred.GetToken(context.Background(), testTRO) + require.NoError(t, err) + require.Equal(t, tokenValue, tk.Token) + require.Equal(t, int32(1), called.Load()) +} + +func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithOptions(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") + if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() + + customTokenEndpointServerCalledTimes := new(atomic.Int32) + customTokenEndpointServer, caData := startTestTokenEndpointWithCAData( + t, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + customTokenEndpointServerCalledTimes.Add(1) + + policyFlowCheck.Validate(t, req) + + require.NoError(t, req.ParseForm()) + require.NotEmpty(t, req.PostForm) + + require.Contains(t, req.PostForm, "client_assertion") + require.Equal(t, req.PostForm.Get("client_assertion"), testClientAssertion) + + require.Contains(t, req.PostForm, "client_id") + require.Equal(t, req.PostForm.Get("client_id"), fakeClientID) + + _, _ = w.Write(accessTokenRespSuccess) + }), + ) + + clientOptions := policy.ClientOptions{ + PerCallPolicies: []policy.Policy{ + policyFlowCheck.Policy(), + }, + } + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + ClientOptions: clientOptions, + EnableAzureProxy: true, + TenantID: fakeTenantID, + TokenFilePath: tempFile, + CustomTokenProxy: &WorkloadIdentityCustomTokenProxyOptions{ + TokenProxy: customTokenEndpointServer.URL, + CAData: caData, + }, + }) + require.NoError(t, err) + require.Nil(t, clientOptions.Transport, "constructor shouldn't mutate caller's ClientOptions") + + testGetTokenSuccess(t, cred) + + require.Equal(t, int32(1), customTokenEndpointServerCalledTimes.Load()) +} + func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { @@ -395,11 +475,11 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) } policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() - customTokenEndointServerCalledTimes := new(atomic.Int32) - customTokenEndointServer, caData := startTestTokenEndpointWithCAData( + customTokenEndpointServerCalledTimes := new(atomic.Int32) + customTokenEndpointServer, caData := startTestTokenEndpointWithCAData( t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - customTokenEndointServerCalledTimes.Add(1) + customTokenEndpointServerCalledTimes.Add(1) policyFlowCheck.Validate(t, req) @@ -416,8 +496,8 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndointServer.URL) - t.Setenv(customtokenproxy.AzureKubernetesCAData, caData) + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, customTokenEndpointServer.URL) + t.Setenv(customtokenproxy.EnvAzureKubernetesCAData, caData) clientOptions := policy.ClientOptions{ PerCallPolicies: []policy.Policy{ @@ -436,11 +516,11 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAData(t *testing.T) testGetTokenSuccess(t, cred) - require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) + require.Equal(t, int32(1), customTokenEndpointServerCalledTimes.Load()) } func TestWorkloadIdentityCredential_CustomTokenEndpoint_InvalidSettings(t *testing.T) { - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, "invalid-token-endpoint") + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, "invalid-token-endpoint") _, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientID: fakeClientID, EnableAzureProxy: true, @@ -478,11 +558,11 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAFile(t *testing.T) }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndointServer.URL) + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, customTokenEndointServer.URL) d := t.TempDir() caFile := filepath.Join(d, "test-ca-file") require.NoError(t, os.WriteFile(caFile, []byte(caData), 0600)) - t.Setenv(customtokenproxy.AzureKubernetesCAFile, caFile) + t.Setenv(customtokenproxy.EnvAzureKubernetesCAFile, caFile) clientOptions := policy.ClientOptions{ PerCallPolicies: []policy.Policy{ @@ -504,7 +584,7 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_WithCAFile(t *testing.T) require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) } -func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup(t *testing.T) { +func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup_FromEnv(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { t.Fatalf("failed to write token file: %v", err) @@ -536,13 +616,13 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup(t *testing.T) { }), ) - t.Setenv(customtokenproxy.AzureKubernetesTokenProxy, customTokenEndointServer.URL) - t.Setenv(customtokenproxy.AzureKubernetesSNIName, sniName) + t.Setenv(customtokenproxy.EnvAzureKubernetesTokenProxy, customTokenEndointServer.URL) + t.Setenv(customtokenproxy.EnvAzureKubernetesSNIName, sniName) d := t.TempDir() caFile := filepath.Join(d, "test-ca-file") require.NoError(t, os.WriteFile(caFile, []byte(caData), 0600)) - t.Setenv(customtokenproxy.AzureKubernetesCAFile, caFile) + t.Setenv(customtokenproxy.EnvAzureKubernetesCAFile, caFile) clientOptions := policy.ClientOptions{ PerCallPolicies: []policy.Policy{ @@ -563,3 +643,64 @@ func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup(t *testing.T) { require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) } + +func TestWorkloadIdentityCredential_CustomTokenEndpoint_AKSSetup_FromOptions(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") + if err := os.WriteFile(tempFile, []byte(testClientAssertion), os.ModePerm); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + policyFlowCheck := newCustomTokenRequestPolicyFlowCheck() + sniName := "test-sni.example.com" + + customTokenEndointServerCalledTimes := new(atomic.Int32) + customTokenEndointServer, caData := startTestTokenEndpointWithCAData( + t, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + customTokenEndointServerCalledTimes.Add(1) + + policyFlowCheck.Validate(t, req) + + require.NotNil(t, req.TLS) + require.Equal(t, req.TLS.ServerName, sniName, "when SNI is set, request should set SNI") + + require.NoError(t, req.ParseForm()) + require.NotEmpty(t, req.PostForm) + + require.Contains(t, req.PostForm, "client_assertion") + require.Equal(t, req.PostForm.Get("client_assertion"), testClientAssertion) + + require.Contains(t, req.PostForm, "client_id") + require.Equal(t, req.PostForm.Get("client_id"), fakeClientID) + + _, _ = w.Write(accessTokenRespSuccess) + }), + ) + + d := t.TempDir() + caFile := filepath.Join(d, "test-ca-file") + require.NoError(t, os.WriteFile(caFile, []byte(caData), 0600)) + + clientOptions := policy.ClientOptions{ + PerCallPolicies: []policy.Policy{ + policyFlowCheck.Policy(), + }, + } + cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ + ClientID: fakeClientID, + ClientOptions: clientOptions, + EnableAzureProxy: true, + TenantID: fakeTenantID, + TokenFilePath: tempFile, + CustomTokenProxy: &WorkloadIdentityCustomTokenProxyOptions{ + TokenProxy: customTokenEndointServer.URL, + CAFile: caFile, + SNIName: sniName, + }, + }) + require.NoError(t, err) + require.Nil(t, clientOptions.Transport, "constructor shouldn't mutate caller's ClientOptions") + + testGetTokenSuccess(t, cred) + + require.Equal(t, int32(1), customTokenEndointServerCalledTimes.Load()) +}