Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 7 additions & 48 deletions client/cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ import (
"fmt"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"

"github.com/Azure/aks-secure-tls-bootstrap/client/internal/bootstrap"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

var bootstrapConfig = new(bootstrap.Config)
Expand Down Expand Up @@ -67,14 +66,16 @@ func main() {
}

func run(ctx context.Context) int {
logger, finalErr := configureLogging(logFile, verbose)
logger, flush, finalErr := log.NewProductionLogger(logFile, verbose)
if finalErr != nil {
fmt.Printf("unable to construct zap logger: %s\n", finalErr)
return 1
}
defer flush(logger)
defer flush()

bootstrapClient, finalErr := bootstrap.NewClient(logger)
ctx = log.WithLogger(telemetry.WithTracing(ctx), logger)

bootstrapClient, finalErr := bootstrap.NewClient(ctx)
if finalErr != nil {
fmt.Printf("unable to construct bootstrap client: %s\n", finalErr)
return 1
Expand All @@ -84,8 +85,7 @@ func run(ctx context.Context) int {
bootstrapDeadline := bootstrapStartTime.Add(bootstrapConfig.Deadline)
logger.Info("set bootstrap deadline", zap.Time("deadline", bootstrapDeadline))

bootstrapCtx := telemetry.WithTracer(ctx, telemetry.NewTracer())
bootstrapCtx, cancel := context.WithDeadline(bootstrapCtx, bootstrapDeadline)
bootstrapCtx, cancel := context.WithDeadline(ctx, bootstrapDeadline)
defer cancel()

finalErr, errLog, traces := bootstrap.Bootstrap(bootstrapCtx, bootstrapClient, bootstrapConfig)
Expand Down Expand Up @@ -137,44 +137,3 @@ func run(ctx context.Context) int {

return exitCode
}

func configureLogging(logFile string, verbose bool) (*zap.Logger, error) {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = "timestamp"
encoderConfig.EncodeTime = zapcore.RFC3339NanoTimeEncoder

level := zap.InfoLevel
if verbose {
level = zap.DebugLevel
}

cores := []zapcore.Core{
zapcore.NewCore(
zapcore.NewConsoleEncoder(encoderConfig),
zapcore.AddSync(os.Stdout),
level,
),
}

if logFile != "" {
if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
logFileHandle, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open log file: %w", err)
}
cores = append(cores, zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
zapcore.AddSync(logFileHandle),
level,
))
}

return zap.New(zapcore.NewTee(cores...)), nil
}

func flush(logger *zap.Logger) {
// per guidance from: https://github.com/uber-go/zap/issues/328
_ = logger.Sync()
}
17 changes: 9 additions & 8 deletions client/internal/bootstrap/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
Expand Down Expand Up @@ -41,18 +42,18 @@ func extractAccessToken(token *adal.ServicePrincipalToken) (string, error) {
// getAccessToken retrieves an AAD access token (JWT) using the specified custom client ID, resource, and cloud provider config.
// MSI access tokens are retrieved from IMDS, while service principal tokens are retrieved directly from AAD.
func (c *Client) getAccessToken(ctx context.Context, customClientID, resource string, cloudProviderConfig *cloud.ProviderConfig) (string, error) {
spanName := "GetAccessToken"
tracer := telemetry.MustGetTracer(ctx)
tracer.StartSpan(spanName)
defer tracer.EndSpan(spanName)
endSpan := telemetry.StartSpan(ctx, "GetAccessToken")
defer endSpan()

logger := log.MustGetLogger(ctx)

userAssignedID := cloudProviderConfig.UserAssignedIdentityID
if customClientID != "" {
userAssignedID = customClientID
}

if userAssignedID != "" {
c.logger.Info("generating MSI access token", zap.String("clientId", userAssignedID))
logger.Info("generating MSI access token", zap.String("clientId", userAssignedID))
token, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, &adal.ManagedIdentityOptions{
ClientID: userAssignedID,
})
Expand All @@ -79,15 +80,15 @@ func (c *Client) getAccessToken(ctx context.Context, customClientID, resource st
}

if !strings.HasPrefix(cloudProviderConfig.ClientSecret, certificateSecretPrefix) {
c.logger.Info("generating SPN access token with username and password", zap.String("clientId", cloudProviderConfig.ClientID))
logger.Info("generating SPN access token with username and password", zap.String("clientId", cloudProviderConfig.ClientID))
token, err := adal.NewServicePrincipalToken(*oauthConfig, cloudProviderConfig.ClientID, cloudProviderConfig.ClientSecret, resource)
if err != nil {
return "", fmt.Errorf("generating SPN access token with username and password: %w", err)
}
return c.extractAccessTokenFunc(token)
}

c.logger.Info("client secret contains certificate data, using certificate to generate SPN access token", zap.String("clientId", cloudProviderConfig.ClientID))
logger.Info("client secret contains certificate data, using certificate to generate SPN access token", zap.String("clientId", cloudProviderConfig.ClientID))

certData, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cloudProviderConfig.ClientSecret, certificateSecretPrefix))
if err != nil {
Expand All @@ -98,7 +99,7 @@ func (c *Client) getAccessToken(ctx context.Context, customClientID, resource st
return "", fmt.Errorf("decoding pfx certificate data in client secret: %w", err)
}

c.logger.Info("generating SPN access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID))
logger.Info("generating SPN access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID))
token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, cloudProviderConfig.ClientID, certificate, privateKey, resource)
if err != nil {
return "", fmt.Errorf("generating SPN access token with certificate: %w", err)
Expand Down
6 changes: 2 additions & 4 deletions client/internal/bootstrap/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
"time"

"github.com/stretchr/testify/assert"
"go.uber.org/zap"

"github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/log"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/testutil"
"github.com/Azure/go-autorest/autorest/adal"
Expand Down Expand Up @@ -193,15 +193,13 @@ func TestGetAccessToken(t *testing.T) {
},
}

logger, _ := zap.NewDevelopment()
testTenantID := "d87a2c3e-0c0c-42b2-a883-e48cd8723e22"
testResource := "resource"

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
ctx := telemetry.NewContext()
ctx := telemetry.WithTracing(log.NewTestContext())
client := &Client{
logger: logger,
extractAccessTokenFunc: c.setupExtractAccessTokenFunc(t),
}
providerCfg := &cloud.ProviderConfig{
Expand Down
17 changes: 3 additions & 14 deletions client/internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"

"github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry"
Expand All @@ -30,7 +28,7 @@ func Bootstrap(ctx context.Context, client *Client, config *Config) (finalErr er
finalErr = retry.Do(
func() error {
defer func() {
traces.Add(telemetry.MustGetTracer(ctx).GetTrace())
traces.Add(telemetry.GetTrace(ctx))
}()

kubeconfigData, err := client.BootstrapKubeletClientCredential(ctx, config)
Expand Down Expand Up @@ -66,17 +64,8 @@ func Bootstrap(ctx context.Context, client *Client, config *Config) (finalErr er
}

func writeKubeconfig(ctx context.Context, config *clientcmdapi.Config, path string) error {
traceName := "WriteKubeconfig"
tracer := telemetry.MustGetTracer(ctx)
tracer.StartSpan(traceName)
defer tracer.EndSpan(traceName)

if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil {
return &BootstrapError{
errorType: ErrorTypeWriteKubeconfigFailure,
inner: fmt.Errorf("creating parent directories for kubeconfig path: %w", err),
}
}
endSpan := telemetry.StartSpan(ctx, "WriteKubeconfig")
defer endSpan()

if err := clientcmd.WriteToFile(*config, path); err != nil {
return &BootstrapError{
Expand Down
Loading
Loading