Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions cmd/acr-credential-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@ import (
"k8s.io/component-base/logs"
"k8s.io/klog/v2"

"sigs.k8s.io/cloud-provider-azure/cmd/acr-credential-provider/pkg/config"
"sigs.k8s.io/cloud-provider-azure/pkg/version"
)

func main() {
rand.Seed(time.Now().UnixNano())

var RegistryMirrorStr string
var (
RegistryMirrorStr string
IBSNIName string
IBDefaultClient string
IBDefaultTenant string
IBAPIIP string
)

command := &cobra.Command{
Use: "acr-credential-provider configFile",
Expand All @@ -54,7 +61,12 @@ func main() {
},
Version: version.Get().GitVersion,
RunE: func(_ *cobra.Command, args []string) error {
if err := NewCredentialProvider(args[0], RegistryMirrorStr).Run(context.TODO()); err != nil {
ibConfig, err := config.ParseIdentityBindingsConfig(IBSNIName, IBDefaultClient, IBDefaultTenant, IBAPIIP)
if err != nil {
klog.Errorf("Error parsing identity bindings config: %v", err)
return err
}
if err := NewCredentialProvider(args[0], RegistryMirrorStr, ibConfig).Run(context.TODO()); err != nil {
klog.Errorf("Error running acr credential provider: %v", err)
return err
}
Expand All @@ -68,6 +80,14 @@ func main() {
// Flags
command.Flags().StringVarP(&RegistryMirrorStr, "registry-mirror", "r", "",
"Mirror a source registry host to a target registry host, and image pull credential will be requested to the target registry host when the image is from source registry host")
command.Flags().StringVar(&IBSNIName, config.FlagIBSNIName, "",
"SNI name for identity bindings")
command.Flags().StringVar(&IBDefaultClient, config.FlagIBDefaultClient, "",
"Default Azure AD client ID for identity bindings")
command.Flags().StringVar(&IBDefaultTenant, config.FlagIBDefaultTenant, "",
"Default Azure AD tenant ID for identity bindings")
command.Flags().StringVar(&IBAPIIP, config.FlagIBAPIIP, "",
"API server IP address for identity bindings endpoint")

logs.AddFlags(command.Flags())
if err := func() error {
Expand Down
76 changes: 76 additions & 0 deletions cmd/acr-credential-provider/pkg/config/identity_bindings_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
Copyright 2021 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package config

import (
"fmt"
"net"
"strings"

"sigs.k8s.io/cloud-provider-azure/pkg/credentialprovider"
)

const (
// Flag names for identity bindings configuration
FlagIBSNIName = "ib-sni-name"
FlagIBDefaultClient = "ib-default-client-id"
FlagIBDefaultTenant = "ib-default-tenant-id"
FlagIBAPIIP = "ib-apiserver-ip"
)

// ParseIdentityBindingsConfig parses and validates identity bindings configuration from individual parameters
func ParseIdentityBindingsConfig(sniName, defaultClientID, defaultTenantID, apiServerIP string) (credentialprovider.IdentityBindingsConfig, error) {

// Validate SNI name
if sniName != "" {
if strings.HasPrefix(sniName, "https://") || strings.HasPrefix(sniName, "http://") {
return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must not contain protocol prefix (https:// or http://), got: %s",
FlagIBSNIName, sniName)
}
if apiServerIP == "" {
return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBAPIIP, FlagIBSNIName)
}
}

// Validate client ID requires SNI name
if defaultClientID != "" && sniName == "" {
return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBSNIName, FlagIBDefaultClient)
}

// Validate tenant ID requires SNI name
if defaultTenantID != "" && sniName == "" {
return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBSNIName, FlagIBDefaultTenant)
}

// Validate API server IP
if apiServerIP != "" {
if net.ParseIP(apiServerIP) == nil {
return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be a valid IP address, got: %s",
FlagIBAPIIP, apiServerIP)
}
if sniName == "" {
return credentialprovider.IdentityBindingsConfig{}, fmt.Errorf("--%s must be set when --%s is provided", FlagIBSNIName, FlagIBAPIIP)
}
}

return credentialprovider.IdentityBindingsConfig{
SNIName: sniName,
DefaultClientID: defaultClientID,
DefaultTenantID: defaultTenantID,
APIServerIP: apiServerIP,
}, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
Copyright 2021 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package config

import (
"strings"
"testing"

"sigs.k8s.io/cloud-provider-azure/pkg/credentialprovider"
)

func TestParseIdentityBindingsConfig(t *testing.T) {
tests := []struct {
name string
sniName string
defaultClientID string
defaultTenantID string
apiServerIP string
wantConfig credentialprovider.IdentityBindingsConfig
wantErr bool
errContains string
}{
{
name: "empty config",
wantConfig: credentialprovider.IdentityBindingsConfig{},
wantErr: false,
},
{
name: "valid config with all fields",
sniName: "api.example.com",
defaultClientID: "client-123",
defaultTenantID: "tenant-456",
apiServerIP: "10.0.0.1",
wantConfig: credentialprovider.IdentityBindingsConfig{
SNIName: "api.example.com",
DefaultClientID: "client-123",
DefaultTenantID: "tenant-456",
APIServerIP: "10.0.0.1",
},
wantErr: false,
},
{
name: "valid config with SNI name and API server IP only",
sniName: "api.example.com",
apiServerIP: "10.0.0.1",
wantConfig: credentialprovider.IdentityBindingsConfig{
SNIName: "api.example.com",
APIServerIP: "10.0.0.1",
},
wantErr: false,
},
{
name: "SNI name with https:// prefix",
sniName: "https://api.example.com",
apiServerIP: "10.0.0.1",
wantErr: true,
errContains: "must not contain protocol prefix",
},
{
name: "SNI name with http:// prefix",
sniName: "http://api.example.com",
apiServerIP: "10.0.0.1",
wantErr: true,
errContains: "must not contain protocol prefix",
},
{
name: "SNI name without API server IP",
sniName: "api.example.com",
wantErr: true,
errContains: "ib-apiserver-ip must be set",
},
{
name: "API server IP without SNI name",
apiServerIP: "10.0.0.1",
wantErr: true,
errContains: "ib-sni-name must be set",
},
{
name: "client ID without SNI name",
defaultClientID: "client-123",
wantErr: true,
errContains: "ib-sni-name must be set",
},
{
name: "tenant ID without SNI name",
defaultTenantID: "tenant-456",
wantErr: true,
errContains: "ib-sni-name must be set",
},
{
name: "invalid API server IP - hostname",
sniName: "api.example.com",
apiServerIP: "invalid-hostname",
wantErr: true,
errContains: "must be a valid IP address",
},
{
name: "invalid API server IP - malformed",
sniName: "api.example.com",
apiServerIP: "999.999.999.999",
wantErr: true,
errContains: "must be a valid IP address",
},
{
name: "valid IPv6 address",
sniName: "api.example.com",
apiServerIP: "2001:db8::1",
wantConfig: credentialprovider.IdentityBindingsConfig{
SNIName: "api.example.com",
APIServerIP: "2001:db8::1",
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotConfig, err := ParseIdentityBindingsConfig(tt.sniName, tt.defaultClientID, tt.defaultTenantID, tt.apiServerIP)
if (err != nil) != tt.wantErr {
t.Errorf("ParseIdentityBindingsConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
if err == nil {
t.Errorf("ParseIdentityBindingsConfig() expected error containing %q, got nil", tt.errContains)
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("ParseIdentityBindingsConfig() error = %v, want error containing %q", err, tt.errContains)
}
return
}
if gotConfig.SNIName != tt.wantConfig.SNIName {
t.Errorf("ParseIdentityBindingsConfig() SNIName = %v, want %v", gotConfig.SNIName, tt.wantConfig.SNIName)
}
if gotConfig.DefaultClientID != tt.wantConfig.DefaultClientID {
t.Errorf("ParseIdentityBindingsConfig() DefaultClientID = %v, want %v", gotConfig.DefaultClientID, tt.wantConfig.DefaultClientID)
}
if gotConfig.DefaultTenantID != tt.wantConfig.DefaultTenantID {
t.Errorf("ParseIdentityBindingsConfig() DefaultTenantID = %v, want %v", gotConfig.DefaultTenantID, tt.wantConfig.DefaultTenantID)
}
if gotConfig.APIServerIP != tt.wantConfig.APIServerIP {
t.Errorf("ParseIdentityBindingsConfig() APIServerIP = %v, want %v", gotConfig.APIServerIP, tt.wantConfig.APIServerIP)
}
})
}
}
6 changes: 4 additions & 2 deletions cmd/acr-credential-provider/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@ func init() {
type ExecPlugin struct {
configFile string
RegistryMirrorStr string
IBConfig credentialprovider.IdentityBindingsConfig
plugin credentialprovider.CredentialProvider
}

// NewCredentialProvider returns an instance of execPlugin that fetches
// credentials based on the provided plugin implementing the CredentialProvider interface.
func NewCredentialProvider(configFile string, registryMirrorStr string) *ExecPlugin {
func NewCredentialProvider(configFile string, registryMirrorStr string, ibConfig credentialprovider.IdentityBindingsConfig) *ExecPlugin {
return &ExecPlugin{
configFile: configFile,
RegistryMirrorStr: registryMirrorStr,
IBConfig: ibConfig,
}
}

Expand Down Expand Up @@ -92,7 +94,7 @@ func (e *ExecPlugin) runPlugin(ctx context.Context, r io.Reader, w io.Writer, ar

if e.plugin == nil {
// acr provider plugin are decided at runtime by the request information.
e.plugin, err = credentialprovider.NewAcrProvider(request, e.RegistryMirrorStr, e.configFile)
e.plugin, err = credentialprovider.NewAcrProvider(request, e.RegistryMirrorStr, e.configFile, e.IBConfig)
if err != nil {
return err
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/acr-credential-provider/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
v1 "k8s.io/kubelet/pkg/apis/credentialprovider/v1"

"sigs.k8s.io/cloud-provider-azure/pkg/credentialprovider"
)

type fakePlugin struct {
Expand Down Expand Up @@ -94,7 +96,7 @@ func Test_runPlugin(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error when writing to temp file: %v", err)
}
p := NewCredentialProvider(configFile.Name(), "mcr.microsoft.com:fakeacrname.azurecr.io")
p := NewCredentialProvider(configFile.Name(), "mcr.microsoft.com:fakeacrname.azurecr.io", credentialprovider.IdentityBindingsConfig{})
p.plugin = &fakePlugin{}
out := &bytes.Buffer{}

Expand Down
35 changes: 20 additions & 15 deletions pkg/credentialprovider/azure_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ type acrProvider struct {
registryMirror map[string]string // Registry mirror relation: source registry -> target registry
}

type getTokenCredentialFunc func(req *v1.CredentialProviderRequest, config *providerconfig.AzureClientConfig) (azcore.TokenCredential, error)

func NewAcrProvider(req *v1.CredentialProviderRequest, registryMirrorStr string, configFile string) (CredentialProvider, error) {
func NewAcrProvider(req *v1.CredentialProviderRequest, registryMirrorStr string, configFile string, ibConfig IdentityBindingsConfig) (CredentialProvider, error) {
config, err := configloader.Load[providerconfig.AzureClientConfig](context.Background(), nil, &configloader.FileLoaderConfig{FilePath: configFile})
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
Expand All @@ -83,20 +81,27 @@ func NewAcrProvider(req *v1.CredentialProviderRequest, registryMirrorStr string,
return nil, err
}

var getTokenCredential getTokenCredentialFunc

// kubelet is responsible for checking the service account token emptiness when service account token is enabled, and only when service account token provide is enabled,
// service account token is set in the request, so we can safely check the service account token emptiness to decide which credential to use.
if len(req.ServiceAccountToken) != 0 {
klog.V(2).Infof("Using service account token to authenticate ACR for image %s", req.Image)
getTokenCredential = getServiceAccountTokenCredential
var credential azcore.TokenCredential
if ibConfig.SNIName != "" {
klog.V(2).Infof("Using identity bindings token credential for image %s", req.Image)
credential, err = GetIdentityBindingsTokenCredential(req, config, ibConfig)
if err != nil {
return nil, fmt.Errorf("failed to get identity bindings token credential for image %s: %w", req.Image, err)
}
} else if len(req.ServiceAccountToken) != 0 {
// Use service account token credential
klog.V(2).Infof("Using service account token credential for image %s", req.Image)
credential, err = getServiceAccountTokenCredential(req, config)
if err != nil {
return nil, fmt.Errorf("failed to get service account token credential for image %s: %w", req.Image, err)
}
} else {
// Use managed identity
klog.V(2).Infof("Using managed identity to authenticate ACR for image %s", req.Image)
getTokenCredential = getManagedIdentityCredential
}
credential, err := getTokenCredential(req, config)
if err != nil {
return nil, fmt.Errorf("failed to get token credential for image %s: %w", req.Image, err)
credential, err = getManagedIdentityCredential(req, config)
if err != nil {
return nil, fmt.Errorf("failed to get token credential for image %s: %w", req.Image, err)
}
}

return &acrProvider{
Expand Down
Loading