diff --git a/cmd/acr-credential-provider/main.go b/cmd/acr-credential-provider/main.go index a61d3307ac..2292060f1a 100644 --- a/cmd/acr-credential-provider/main.go +++ b/cmd/acr-credential-provider/main.go @@ -31,13 +31,20 @@ import ( "k8s.io/component-base/logs" "k8s.io/klog/v2" + "sigs.k8s.io/cloud-provider-azure/cmd/acr-credential-provider/pkg/config" "sigs.k8s.io/cloud-provider-azure/pkg/version" ) func main() { rand.Seed(time.Now().UnixNano()) - var RegistryMirrorStr string + var ( + RegistryMirrorStr string + IBSNIName string + IBDefaultClient string + IBDefaultTenant string + IBAPIIP string + ) command := &cobra.Command{ Use: "acr-credential-provider configFile", @@ -54,7 +61,12 @@ func main() { }, Version: version.Get().GitVersion, RunE: func(_ *cobra.Command, args []string) error { - if err := NewCredentialProvider(args[0], RegistryMirrorStr).Run(context.TODO()); err != nil { + ibConfig, err := config.ParseIdentityBindingsConfig(IBSNIName, IBDefaultClient, IBDefaultTenant, IBAPIIP) + if err != nil { + klog.Errorf("Error parsing identity bindings config: %v", err) + return err + } + if err := NewCredentialProvider(args[0], RegistryMirrorStr, ibConfig).Run(context.TODO()); err != nil { klog.Errorf("Error running acr credential provider: %v", err) return err } @@ -68,6 +80,14 @@ func main() { // Flags command.Flags().StringVarP(&RegistryMirrorStr, "registry-mirror", "r", "", "Mirror a source registry host to a target registry host, and image pull credential will be requested to the target registry host when the image is from source registry host") + command.Flags().StringVar(&IBSNIName, config.FlagIBSNIName, "", + "SNI name for identity bindings") + command.Flags().StringVar(&IBDefaultClient, config.FlagIBDefaultClient, "", + "Default Azure AD client ID for identity bindings") + command.Flags().StringVar(&IBDefaultTenant, config.FlagIBDefaultTenant, "", + "Default Azure AD tenant ID for identity bindings") + command.Flags().StringVar(&IBAPIIP, config.FlagIBAPIIP, "", + "API server IP address for identity bindings endpoint") logs.AddFlags(command.Flags()) if err := func() error { diff --git a/cmd/acr-credential-provider/pkg/config/identity_bindings_config.go b/cmd/acr-credential-provider/pkg/config/identity_bindings_config.go new file mode 100644 index 0000000000..d9f9e92bbc --- /dev/null +++ b/cmd/acr-credential-provider/pkg/config/identity_bindings_config.go @@ -0,0 +1,76 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "fmt" + "net" + "strings" + + "sigs.k8s.io/cloud-provider-azure/pkg/credentialprovider" +) + +const ( + // Flag names for identity bindings configuration + FlagIBSNIName = "ib-sni-name" + FlagIBDefaultClient = "ib-default-client-id" + FlagIBDefaultTenant = "ib-default-tenant-id" + FlagIBAPIIP = "ib-apiserver-ip" +) + +// ParseIdentityBindingsConfig parses and validates identity bindings configuration from individual parameters +func ParseIdentityBindingsConfig(sniName, defaultClientID, defaultTenantID, apiServerIP string) (credentialprovider.IdentityBindingsConfig, error) { + + // Validate SNI name + if sniName != "" { + if strings.HasPrefix(sniName, "https://") || strings.HasPrefix(sniName, "http://") { + return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must not contain protocol prefix (https:// or http://), got: %s", + FlagIBSNIName, sniName) + } + if apiServerIP == "" { + return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBAPIIP, FlagIBSNIName) + } + } + + // Validate client ID requires SNI name + if defaultClientID != "" && sniName == "" { + return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBSNIName, FlagIBDefaultClient) + } + + // Validate tenant ID requires SNI name + if defaultTenantID != "" && sniName == "" { + return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBSNIName, FlagIBDefaultTenant) + } + + // Validate API server IP + if apiServerIP != "" { + if net.ParseIP(apiServerIP) == nil { + return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be a valid IP address, got: %s", + FlagIBAPIIP, apiServerIP) + } + if sniName == "" { + return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBSNIName, FlagIBAPIIP) + } + } + + return credentialprovider.IdentityBindingsConfig{ + SNIName: sniName, + DefaultClientID: defaultClientID, + DefaultTenantID: defaultTenantID, + APIServerIP: apiServerIP, + }, nil +} diff --git a/cmd/acr-credential-provider/pkg/config/identity_bindings_config_test.go b/cmd/acr-credential-provider/pkg/config/identity_bindings_config_test.go new file mode 100644 index 0000000000..135aee9012 --- /dev/null +++ b/cmd/acr-credential-provider/pkg/config/identity_bindings_config_test.go @@ -0,0 +1,159 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "strings" + "testing" + + "sigs.k8s.io/cloud-provider-azure/pkg/credentialprovider" +) + +func TestParseIdentityBindingsConfig(t *testing.T) { + tests := []struct { + name string + sniName string + defaultClientID string + defaultTenantID string + apiServerIP string + wantConfig credentialprovider.IdentityBindingsConfig + wantErr bool + errContains string + }{ + { + name: "empty config", + wantConfig: credentialprovider.IdentityBindingsConfig{}, + wantErr: false, + }, + { + name: "valid config with all fields", + sniName: "api.example.com", + defaultClientID: "client-123", + defaultTenantID: "tenant-456", + apiServerIP: "10.0.0.1", + wantConfig: credentialprovider.IdentityBindingsConfig{ + SNIName: "api.example.com", + DefaultClientID: "client-123", + DefaultTenantID: "tenant-456", + APIServerIP: "10.0.0.1", + }, + wantErr: false, + }, + { + name: "valid config with SNI name and API server IP only", + sniName: "api.example.com", + apiServerIP: "10.0.0.1", + wantConfig: credentialprovider.IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "10.0.0.1", + }, + wantErr: false, + }, + { + name: "SNI name with https:// prefix", + sniName: "https://api.example.com", + apiServerIP: "10.0.0.1", + wantErr: true, + errContains: "must not contain protocol prefix", + }, + { + name: "SNI name with http:// prefix", + sniName: "http://api.example.com", + apiServerIP: "10.0.0.1", + wantErr: true, + errContains: "must not contain protocol prefix", + }, + { + name: "SNI name without API server IP", + sniName: "api.example.com", + wantErr: true, + errContains: "ib-apiserver-ip must be set", + }, + { + name: "API server IP without SNI name", + apiServerIP: "10.0.0.1", + wantErr: true, + errContains: "ib-sni-name must be set", + }, + { + name: "client ID without SNI name", + defaultClientID: "client-123", + wantErr: true, + errContains: "ib-sni-name must be set", + }, + { + name: "tenant ID without SNI name", + defaultTenantID: "tenant-456", + wantErr: true, + errContains: "ib-sni-name must be set", + }, + { + name: "invalid API server IP - hostname", + sniName: "api.example.com", + apiServerIP: "invalid-hostname", + wantErr: true, + errContains: "must be a valid IP address", + }, + { + name: "invalid API server IP - malformed", + sniName: "api.example.com", + apiServerIP: "999.999.999.999", + wantErr: true, + errContains: "must be a valid IP address", + }, + { + name: "valid IPv6 address", + sniName: "api.example.com", + apiServerIP: "2001:db8::1", + wantConfig: credentialprovider.IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "2001:db8::1", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotConfig, err := ParseIdentityBindingsConfig(tt.sniName, tt.defaultClientID, tt.defaultTenantID, tt.apiServerIP) + if (err != nil) != tt.wantErr { + t.Errorf("ParseIdentityBindingsConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("ParseIdentityBindingsConfig() expected error containing %q, got nil", tt.errContains) + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("ParseIdentityBindingsConfig() error = %v, want error containing %q", err, tt.errContains) + } + return + } + if gotConfig.SNIName != tt.wantConfig.SNIName { + t.Errorf("ParseIdentityBindingsConfig() SNIName = %v, want %v", gotConfig.SNIName, tt.wantConfig.SNIName) + } + if gotConfig.DefaultClientID != tt.wantConfig.DefaultClientID { + t.Errorf("ParseIdentityBindingsConfig() DefaultClientID = %v, want %v", gotConfig.DefaultClientID, tt.wantConfig.DefaultClientID) + } + if gotConfig.DefaultTenantID != tt.wantConfig.DefaultTenantID { + t.Errorf("ParseIdentityBindingsConfig() DefaultTenantID = %v, want %v", gotConfig.DefaultTenantID, tt.wantConfig.DefaultTenantID) + } + if gotConfig.APIServerIP != tt.wantConfig.APIServerIP { + t.Errorf("ParseIdentityBindingsConfig() APIServerIP = %v, want %v", gotConfig.APIServerIP, tt.wantConfig.APIServerIP) + } + }) + } +} diff --git a/cmd/acr-credential-provider/plugin.go b/cmd/acr-credential-provider/plugin.go index e64597cdd9..0b79ffa397 100644 --- a/cmd/acr-credential-provider/plugin.go +++ b/cmd/acr-credential-provider/plugin.go @@ -46,15 +46,17 @@ func init() { type ExecPlugin struct { configFile string RegistryMirrorStr string + IBConfig credentialprovider.IdentityBindingsConfig plugin credentialprovider.CredentialProvider } // NewCredentialProvider returns an instance of execPlugin that fetches // credentials based on the provided plugin implementing the CredentialProvider interface. -func NewCredentialProvider(configFile string, registryMirrorStr string) *ExecPlugin { +func NewCredentialProvider(configFile string, registryMirrorStr string, ibConfig credentialprovider.IdentityBindingsConfig) *ExecPlugin { return &ExecPlugin{ configFile: configFile, RegistryMirrorStr: registryMirrorStr, + IBConfig: ibConfig, } } @@ -92,7 +94,7 @@ func (e *ExecPlugin) runPlugin(ctx context.Context, r io.Reader, w io.Writer, ar if e.plugin == nil { // acr provider plugin are decided at runtime by the request information. - e.plugin, err = credentialprovider.NewAcrProvider(request, e.RegistryMirrorStr, e.configFile) + e.plugin, err = credentialprovider.NewAcrProvider(request, e.RegistryMirrorStr, e.configFile, e.IBConfig) if err != nil { return err } diff --git a/cmd/acr-credential-provider/plugin_test.go b/cmd/acr-credential-provider/plugin_test.go index 85a794f184..91c7b4e48c 100644 --- a/cmd/acr-credential-provider/plugin_test.go +++ b/cmd/acr-credential-provider/plugin_test.go @@ -26,6 +26,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1" + + "sigs.k8s.io/cloud-provider-azure/pkg/credentialprovider" ) type fakePlugin struct { @@ -94,7 +96,7 @@ func Test_runPlugin(t *testing.T) { if err != nil { t.Fatalf("Unexpected error when writing to temp file: %v", err) } - p := NewCredentialProvider(configFile.Name(), "mcr.microsoft.com:fakeacrname.azurecr.io") + p := NewCredentialProvider(configFile.Name(), "mcr.microsoft.com:fakeacrname.azurecr.io", credentialprovider.IdentityBindingsConfig{}) p.plugin = &fakePlugin{} out := &bytes.Buffer{} diff --git a/pkg/credentialprovider/azure_credentials.go b/pkg/credentialprovider/azure_credentials.go index 70137e52c8..004cf992bf 100644 --- a/pkg/credentialprovider/azure_credentials.go +++ b/pkg/credentialprovider/azure_credentials.go @@ -66,9 +66,7 @@ type acrProvider struct { registryMirror map[string]string // Registry mirror relation: source registry -> target registry } -type getTokenCredentialFunc func(req *v1.CredentialProviderRequest, config *providerconfig.AzureClientConfig) (azcore.TokenCredential, error) - -func NewAcrProvider(req *v1.CredentialProviderRequest, registryMirrorStr string, configFile string) (CredentialProvider, error) { +func NewAcrProvider(req *v1.CredentialProviderRequest, registryMirrorStr string, configFile string, ibConfig IdentityBindingsConfig) (CredentialProvider, error) { config, err := configloader.Load[providerconfig.AzureClientConfig](context.Background(), nil, &configloader.FileLoaderConfig{FilePath: configFile}) if err != nil { return nil, fmt.Errorf("failed to load config: %w", err) @@ -83,20 +81,27 @@ func NewAcrProvider(req *v1.CredentialProviderRequest, registryMirrorStr string, return nil, err } - var getTokenCredential getTokenCredentialFunc - - // kubelet is responsible for checking the service account token emptiness when service account token is enabled, and only when service account token provide is enabled, - // service account token is set in the request, so we can safely check the service account token emptiness to decide which credential to use. - if len(req.ServiceAccountToken) != 0 { - klog.V(2).Infof("Using service account token to authenticate ACR for image %s", req.Image) - getTokenCredential = getServiceAccountTokenCredential + var credential azcore.TokenCredential + if ibConfig.SNIName != "" { + klog.V(2).Infof("Using identity bindings token credential for image %s", req.Image) + credential, err = GetIdentityBindingsTokenCredential(req, config, ibConfig) + if err != nil { + return nil, fmt.Errorf("failed to get identity bindings token credential for image %s: %w", req.Image, err) + } + } else if len(req.ServiceAccountToken) != 0 { + // Use service account token credential + klog.V(2).Infof("Using service account token credential for image %s", req.Image) + credential, err = getServiceAccountTokenCredential(req, config) + if err != nil { + return nil, fmt.Errorf("failed to get service account token credential for image %s: %w", req.Image, err) + } } else { + // Use managed identity klog.V(2).Infof("Using managed identity to authenticate ACR for image %s", req.Image) - getTokenCredential = getManagedIdentityCredential - } - credential, err := getTokenCredential(req, config) - if err != nil { - return nil, fmt.Errorf("failed to get token credential for image %s: %w", req.Image, err) + credential, err = getManagedIdentityCredential(req, config) + if err != nil { + return nil, fmt.Errorf("failed to get token credential for image %s: %w", req.Image, err) + } } return &acrProvider{ diff --git a/pkg/credentialprovider/azure_credentials_test.go b/pkg/credentialprovider/azure_credentials_test.go index 9d0bb9367d..663560c8ec 100644 --- a/pkg/credentialprovider/azure_credentials_test.go +++ b/pkg/credentialprovider/azure_credentials_test.go @@ -65,6 +65,7 @@ func TestGetCredentials(t *testing.T) { }, "", configFile.Name(), + IdentityBindingsConfig{}, ) if err != nil { @@ -166,6 +167,7 @@ func TestGetCredentialsConfig(t *testing.T) { }, "", configFile.Name(), + IdentityBindingsConfig{}, ) if err != nil && !test.expectError { t.Fatalf("Unexpected error when creating new acr provider: %v", err) @@ -208,6 +210,7 @@ func TestProcessImageWithMirrorMapping(t *testing.T) { }, "mcr.microsoft.com:abc.azurecr.io", configFile.Name(), + IdentityBindingsConfig{}, ) assert.Nilf(t, err, "Unexpected error when creating new acr provider") @@ -264,6 +267,7 @@ func TestParseACRLoginServerFromImage(t *testing.T) { }, "mcr.microsoft.com:abc.azurecr.io", configFile.Name(), + IdentityBindingsConfig{}, ) if err != nil { t.Fatalf("Unexpected error when creating new acr provider: %v", err) @@ -402,7 +406,7 @@ func TestNewAcrProvider_WithEmptyServiceAccountToken(t *testing.T) { ServiceAccountToken: "", // Empty token } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.NoError(t, err) assert.NotNil(t, provider) @@ -433,7 +437,7 @@ func TestNewAcrProvider_WithServiceAccountToken(t *testing.T) { }, } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.NoError(t, err) assert.NotNil(t, provider) @@ -463,7 +467,7 @@ func TestNewAcrProvider_WithServiceAccountToken_MissingClientIDAnnotation(t *tes }, } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.Error(t, err) assert.Nil(t, provider) assert.Contains(t, err.Error(), "client id annotation") @@ -491,7 +495,7 @@ func TestNewAcrProvider_WithServiceAccountToken_MissingTenantIDAnnotation(t *tes }, } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.Error(t, err) assert.Nil(t, provider) assert.Contains(t, err.Error(), "tenant id annotation") @@ -519,7 +523,7 @@ func TestNewAcrProvider_WithServiceAccountToken_EmptyClientID(t *testing.T) { }, } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.Error(t, err) assert.Nil(t, provider) assert.Contains(t, err.Error(), "client id annotation") @@ -547,7 +551,7 @@ func TestNewAcrProvider_WithServiceAccountToken_EmptyTenantID(t *testing.T) { }, } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.Error(t, err) assert.Nil(t, provider) assert.Contains(t, err.Error(), "tenant id annotation") @@ -570,7 +574,7 @@ func TestNewAcrProvider_InvalidConfig(t *testing.T) { ServiceAccountToken: "", } - provider, err := NewAcrProvider(req, "", configFile.Name()) + provider, err := NewAcrProvider(req, "", configFile.Name(), IdentityBindingsConfig{}) assert.Error(t, err) assert.Nil(t, provider) assert.Contains(t, err.Error(), "failed to load config") diff --git a/pkg/credentialprovider/identity_bindings_config.go b/pkg/credentialprovider/identity_bindings_config.go new file mode 100644 index 0000000000..b473bb1703 --- /dev/null +++ b/pkg/credentialprovider/identity_bindings_config.go @@ -0,0 +1,25 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentialprovider + +// IdentityBindingsConfig contains configuration for identity bindings based authentication +type IdentityBindingsConfig struct { + SNIName string + DefaultClientID string + DefaultTenantID string + APIServerIP string +} diff --git a/pkg/credentialprovider/identity_bindings_credentials.go b/pkg/credentialprovider/identity_bindings_credentials.go new file mode 100644 index 0000000000..ea40cdc2b0 --- /dev/null +++ b/pkg/credentialprovider/identity_bindings_credentials.go @@ -0,0 +1,267 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentialprovider + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + providerconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "k8s.io/klog/v2" + v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1" +) + +const ( + // Kubernetes certificate path + KubernetesCACertPath = "/etc/kubernetes/certs/ca.crt" +) + +// identityBindingsTokenCredential implements azcore.TokenCredential interface +// using identity bindings token exchange +type identityBindingsTokenCredential struct { + token string + clientID string + // tenantID is reserved for future SDK compatibility and may be used in token endpoint construction. + tenantID string + config *providerconfig.AzureClientConfig + ibConfig IdentityBindingsConfig + endpoint string + transport *http.Transport +} + +// tokenResponse represents the response from identity bindings token exchange +type tokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +// createTransport creates an HTTP transport with custom CA +// The transport uses a custom dialer that resolves the SNI name to the configured API server IP +func createTransport(sniName string, apiServerIP string, caPool *x509.CertPool) *http.Transport { + transport := http.DefaultTransport.(*http.Transport).Clone() + // reset Proxy to avoid using environment proxy settings + transport.Proxy = nil + + // Custom dialer that resolves the SNI hostname to the fixed API server IP + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + // Extract port from addr (format is "host:port") + _, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("failed to parse address %s: %w", addr, err) + } + + // Always connect to the configured API server IP + fixedAddr := net.JoinHostPort(apiServerIP, port) + klog.V(5).Infof("Identity bindings: resolving %s to %s", addr, fixedAddr) + + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + return dialer.DialContext(ctx, network, fixedAddr) + } + + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, // #nosec G402 + } + } + transport.TLSClientConfig.ServerName = sniName + // Explicitly set minimum TLS version to TLS 1.2 for security + transport.TLSClientConfig.MinVersion = tls.VersionTLS12 + + // Set custom CA pool if provided + if caPool != nil { + transport.TLSClientConfig.RootCAs = caPool + } + + return transport +} + +// getTransport provides the transport to use for the request +func (c *identityBindingsTokenCredential) getTransport() (*http.Transport, error) { + // Return existing transport if already created + if c.transport != nil { + return c.transport, nil + } + + // Read CA file + b, err := os.ReadFile(KubernetesCACertPath) + if err != nil { + return nil, fmt.Errorf("read CA file %q: %w", KubernetesCACertPath, err) + } + if len(b) == 0 { + return nil, fmt.Errorf("CA file %q is empty", KubernetesCACertPath) + } + + // Create CA pool + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(b) { + return nil, fmt.Errorf("parse CA file %q: no valid certificates found", KubernetesCACertPath) + } + + // Create and cache transport + c.transport = createTransport(c.ibConfig.SNIName, c.ibConfig.APIServerIP, caPool) + + return c.transport, nil +} + +// GetToken retrieves an access token using identity bindings token exchange +func (c *identityBindingsTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + // The scope should be exactly one value in format "https://management.azure.com/.default" + // or "https://containerregistry.azure.net/.default" + if len(opts.Scopes) != 1 { + return azcore.AccessToken{}, fmt.Errorf("expected exactly one scope, got %d", len(opts.Scopes)) + } + + scope := opts.Scopes[0] + + // Use stored client assertion token + clientAssertion := c.token + if clientAssertion == "" { + return azcore.AccessToken{}, fmt.Errorf("service account token not found") + } + + // Use stored client ID + clientID := c.clientID + if clientID == "" { + return azcore.AccessToken{}, fmt.Errorf("client ID not configured") + } + + // Prepare form data + formData := url.Values{} + formData.Set("grant_type", "client_credentials") + formData.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + formData.Set("scope", scope) + formData.Set("client_assertion", clientAssertion) + formData.Set("client_id", clientID) + + // Create request + req, err := http.NewRequestWithContext(ctx, "POST", c.endpoint, strings.NewReader(formData.Encode())) + if err != nil { + return azcore.AccessToken{}, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + klog.V(4).Infof("Requesting token from identity bindings endpoint: %s with scope: %s", c.endpoint, scope) + + // Get transport (handles CA rotation) + transport, err := c.getTransport() + if err != nil { + return azcore.AccessToken{}, fmt.Errorf("failed to get transport: %w", err) + } + + // Execute request + httpClient := &http.Client{Transport: transport} + resp, err := httpClient.Do(req) + if err != nil { + return azcore.AccessToken{}, fmt.Errorf("failed to execute token request: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return azcore.AccessToken{}, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return azcore.AccessToken{}, fmt.Errorf("token request failed with status %d", resp.StatusCode) + } + + // Parse response + var tokenResp tokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return azcore.AccessToken{}, fmt.Errorf("failed to parse token response: %w", err) + } + + expiresOn := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + klog.V(4).Infof("Successfully obtained token from identity bindings, expires at: %s", expiresOn) + + return azcore.AccessToken{ + Token: tokenResp.AccessToken, + ExpiresOn: expiresOn, + }, nil +} + +func GetIdentityBindingsTokenCredential(req *v1.CredentialProviderRequest, config *providerconfig.AzureClientConfig, ibConfig IdentityBindingsConfig) (azcore.TokenCredential, error) { + klog.V(2).Infof("Using identity bindings token credential for image %s", req.Image) + + // Get SNI name from config + sniName := ibConfig.SNIName + if sniName == "" { + return nil, fmt.Errorf("SNI name not provided in identity bindings config") + } + + // Get API server IP from config + apiServerIP := ibConfig.APIServerIP + if apiServerIP == "" { + return nil, fmt.Errorf("API server IP not provided in identity bindings config") + } + + // Get service account token + token := req.ServiceAccountToken + if token == "" { + return nil, fmt.Errorf("service account token not found in request") + } + + // Resolve client ID from annotation or use default + var clientID string + if id, ok := req.ServiceAccountAnnotations[clientIDAnnotation]; ok { + clientID = id + } else { + clientID = ibConfig.DefaultClientID + } + if clientID == "" { + return nil, fmt.Errorf("client ID not found in service account annotations (checked %s) and no default client ID configured", + clientIDAnnotation) + } + + // Resolve tenant ID from annotation or use default + var tenantID string + if id, ok := req.ServiceAccountAnnotations[tenantIDAnnotation]; ok { + tenantID = id + } else { + tenantID = ibConfig.DefaultTenantID + } + + // Build endpoint URL + endpoint := "https://" + sniName + + return &identityBindingsTokenCredential{ + token: token, + clientID: clientID, + tenantID: tenantID, + config: config, + ibConfig: ibConfig, + endpoint: endpoint, + }, nil +} diff --git a/pkg/credentialprovider/identity_bindings_credentials_test.go b/pkg/credentialprovider/identity_bindings_credentials_test.go new file mode 100644 index 0000000000..c2f3d3a10b --- /dev/null +++ b/pkg/credentialprovider/identity_bindings_credentials_test.go @@ -0,0 +1,441 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package credentialprovider + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "fmt" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1" + + providerconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" +) + +func TestGetIdentityBindingsTokenCredential(t *testing.T) { + tests := []struct { + name string + ibConfig IdentityBindingsConfig + wantErr bool + errContains string + }{ + { + name: "valid config", + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "10.0.0.1", + }, + wantErr: false, + }, + { + name: "missing SNI name", + ibConfig: IdentityBindingsConfig{ + APIServerIP: "10.0.0.1", + }, + wantErr: true, + errContains: "SNI name not provided", + }, + { + name: "missing API server IP", + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + }, + wantErr: true, + errContains: "API server IP not provided", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &v1.CredentialProviderRequest{ + Image: "test.azurecr.io/test:latest", + ServiceAccountToken: "test-sa-token", + ServiceAccountAnnotations: map[string]string{ + clientIDAnnotation: "test-client-123", + }, + } + config := &providerconfig.AzureClientConfig{} + + cred, err := GetIdentityBindingsTokenCredential(req, config, tt.ibConfig) + if (err != nil) != tt.wantErr { + t.Errorf("GetIdentityBindingsTokenCredential() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error = %v, want error containing %q", err, tt.errContains) + } + return + } + if cred == nil { + t.Error("expected non-nil credential") + } + }) + } +} + +func TestCreateTransport(t *testing.T) { + // Set HTTPS_PROXY environment variable to verify it's ignored + originalHTTPSProxy := os.Getenv("HTTPS_PROXY") + defer func() { + if originalHTTPSProxy != "" { + os.Setenv("HTTPS_PROXY", originalHTTPSProxy) + } else { + os.Unsetenv("HTTPS_PROXY") + } + }() + + // Set proxy environment variables + os.Setenv("HTTPS_PROXY", "http://proxy.example.com:8080") + + // Generate a self-signed certificate for testing + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "api.example.com", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"api.example.com"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + } + + caPool := x509.NewCertPool() + caPool.AddCert(cert) + + tests := []struct { + name string + sniName string + apiServerIP string + caPool *x509.CertPool + }{ + { + name: "with CA pool", + sniName: "api.example.com", + apiServerIP: "10.0.0.1", + caPool: caPool, + }, + { + name: "without CA pool", + sniName: "api.example.com", + apiServerIP: "10.0.0.1", + caPool: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := createTransport(tt.sniName, tt.apiServerIP, tt.caPool) + if transport == nil { + t.Error("expected non-nil transport") + return + } + if transport.TLSClientConfig == nil { + t.Error("expected non-nil TLSClientConfig") + return + } + if transport.TLSClientConfig.ServerName != tt.sniName { + t.Errorf("ServerName = %v, want %v", transport.TLSClientConfig.ServerName, tt.sniName) + } + if tt.caPool != nil { + if transport.TLSClientConfig.RootCAs == nil { + t.Error("expected non-nil RootCAs when caPool provided") + } + } + if transport.DialContext == nil { + t.Error("expected non-nil DialContext") + } + // Verify that Proxy is explicitly set to nil to avoid using environment proxy settings + if transport.Proxy != nil { + t.Error("expected Proxy to be nil to bypass environment proxy settings") + } + }) + } +} + +func TestIdentityBindingsTokenCredential_GetToken(t *testing.T) { + // Create a test server + mux := http.NewServeMux() + var formDataReceived map[string]string + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + formDataReceived = make(map[string]string) + for key := range r.Form { + formDataReceived[key] = r.Form.Get(key) + } + + resp := tokenResponse{ + AccessToken: "test-token", + ExpiresIn: 3600, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + return + } + }) + + server := httptest.NewUnstartedServer(mux) + + // Generate certificate for TLS + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "api.example.com", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"api.example.com"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + cert := tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: priv, + } + + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, // #nosec G402 + } + server.StartTLS() + defer server.Close() + + // Parse server URL to get the port + serverURL := server.URL + _, port, _ := net.SplitHostPort(strings.TrimPrefix(serverURL, "https://")) + + // Create CA pool from the server's certificate + parsedCert, _ := x509.ParseCertificate(certDER) + caPool := x509.NewCertPool() + caPool.AddCert(parsedCert) + + tests := []struct { + name string + req *v1.CredentialProviderRequest + ibConfig IdentityBindingsConfig + scopes []string + wantErr bool + errContains string + checkFormData bool + expectedGrantType string + }{ + { + name: "successful token retrieval with client ID from annotation", + req: &v1.CredentialProviderRequest{ + Image: "test.azurecr.io/test:latest", + ServiceAccountToken: "test-sa-token", + ServiceAccountAnnotations: map[string]string{ + clientIDAnnotation: "client-123", + }, + }, + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "127.0.0.1", + }, + scopes: []string{"https://containerregistry.azure.net/.default"}, + wantErr: false, + checkFormData: true, + expectedGrantType: "client_credentials", + }, + { + name: "successful token retrieval with default client ID", + req: &v1.CredentialProviderRequest{ + Image: "test.azurecr.io/test:latest", + ServiceAccountToken: "test-sa-token", + ServiceAccountAnnotations: map[string]string{}, + }, + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "127.0.0.1", + DefaultClientID: "default-client-456", + }, + scopes: []string{"https://containerregistry.azure.net/.default"}, + wantErr: false, + checkFormData: true, + expectedGrantType: "client_credentials", + }, + { + name: "missing service account token", + req: &v1.CredentialProviderRequest{ + Image: "test.azurecr.io/test:latest", + ServiceAccountToken: "", + ServiceAccountAnnotations: map[string]string{}, + }, + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "127.0.0.1", + }, + scopes: []string{"https://containerregistry.azure.net/.default"}, + wantErr: true, + errContains: "service account token not found", + }, + { + name: "missing client ID", + req: &v1.CredentialProviderRequest{ + Image: "test.azurecr.io/test:latest", + ServiceAccountToken: "test-sa-token", + ServiceAccountAnnotations: map[string]string{}, + }, + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "127.0.0.1", + }, + scopes: []string{"https://containerregistry.azure.net/.default"}, + wantErr: true, + errContains: "client ID not configured", + }, + { + name: "invalid scope count", + req: &v1.CredentialProviderRequest{ + Image: "test.azurecr.io/test:latest", + ServiceAccountToken: "test-sa-token", + ServiceAccountAnnotations: map[string]string{ + clientIDAnnotation: "client-123", + }, + }, + ibConfig: IdentityBindingsConfig{ + SNIName: "api.example.com", + APIServerIP: "127.0.0.1", + }, + scopes: []string{"scope1", "scope2"}, + wantErr: true, + errContains: "expected exactly one scope", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset captured data + formDataReceived = nil + + endpoint := fmt.Sprintf("https://api.example.com:%s", port) + transport := createTransport(tt.ibConfig.SNIName, tt.ibConfig.APIServerIP, caPool) + + // Determine client ID from annotation or default + var clientID string + if id, ok := tt.req.ServiceAccountAnnotations[clientIDAnnotation]; ok { + clientID = id + } else { + clientID = tt.ibConfig.DefaultClientID + } + + // Determine tenant ID from annotation or default + var tenantID string + if id, ok := tt.req.ServiceAccountAnnotations[tenantIDAnnotation]; ok { + tenantID = id + } else { + tenantID = tt.ibConfig.DefaultTenantID + } + + cred := &identityBindingsTokenCredential{ + token: tt.req.ServiceAccountToken, + clientID: clientID, + tenantID: tenantID, + config: &providerconfig.AzureClientConfig{}, + ibConfig: tt.ibConfig, + endpoint: endpoint, + transport: transport, + } + + ctx := context.Background() + opts := policy.TokenRequestOptions{ + Scopes: tt.scopes, + } + + token, err := cred.GetToken(ctx, opts) + if (err != nil) != tt.wantErr { + t.Errorf("GetToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error = %v, want error containing %q", err, tt.errContains) + } + return + } + + if token.Token != "test-token" { + t.Errorf("token = %v, want %v", token.Token, "test-token") + } + + if tt.checkFormData && formDataReceived != nil { + if formDataReceived["grant_type"] != tt.expectedGrantType { + t.Errorf("grant_type = %v, want %v", formDataReceived["grant_type"], tt.expectedGrantType) + } + if formDataReceived["client_assertion"] != tt.req.ServiceAccountToken { + t.Errorf("client_assertion = %v, want %v", formDataReceived["client_assertion"], tt.req.ServiceAccountToken) + } + if formDataReceived["scope"] != tt.scopes[0] { + t.Errorf("scope = %v, want %v", formDataReceived["scope"], tt.scopes[0]) + } + } + }) + } +}