diff --git a/client/cmd/client/main.go b/client/cmd/client/main.go index 60417f7..983172e 100644 --- a/client/cmd/client/main.go +++ b/client/cmd/client/main.go @@ -10,14 +10,13 @@ import ( "fmt" "os" "os/signal" - "path/filepath" "syscall" "time" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/bootstrap" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry" "go.uber.org/zap" - "go.uber.org/zap/zapcore" ) var bootstrapConfig = new(bootstrap.Config) @@ -67,14 +66,16 @@ func main() { } func run(ctx context.Context) int { - logger, finalErr := configureLogging(logFile, verbose) + logger, flush, finalErr := log.NewProductionLogger(logFile, verbose) if finalErr != nil { fmt.Printf("unable to construct zap logger: %s\n", finalErr) return 1 } - defer flush(logger) + defer flush() - bootstrapClient, finalErr := bootstrap.NewClient(logger) + ctx = log.WithLogger(telemetry.WithTracing(ctx), logger) + + bootstrapClient, finalErr := bootstrap.NewClient(ctx) if finalErr != nil { fmt.Printf("unable to construct bootstrap client: %s\n", finalErr) return 1 @@ -84,8 +85,7 @@ func run(ctx context.Context) int { bootstrapDeadline := bootstrapStartTime.Add(bootstrapConfig.Deadline) logger.Info("set bootstrap deadline", zap.Time("deadline", bootstrapDeadline)) - bootstrapCtx := telemetry.WithTracer(ctx, telemetry.NewTracer()) - bootstrapCtx, cancel := context.WithDeadline(bootstrapCtx, bootstrapDeadline) + bootstrapCtx, cancel := context.WithDeadline(ctx, bootstrapDeadline) defer cancel() finalErr, errLog, traces := bootstrap.Bootstrap(bootstrapCtx, bootstrapClient, bootstrapConfig) @@ -137,44 +137,3 @@ func run(ctx context.Context) int { return exitCode } - -func configureLogging(logFile string, verbose bool) (*zap.Logger, error) { - encoderConfig := zap.NewProductionEncoderConfig() - encoderConfig.TimeKey = "timestamp" - encoderConfig.EncodeTime = zapcore.RFC3339NanoTimeEncoder - - level := zap.InfoLevel - if verbose { - level = zap.DebugLevel - } - - cores := []zapcore.Core{ - zapcore.NewCore( - zapcore.NewConsoleEncoder(encoderConfig), - zapcore.AddSync(os.Stdout), - level, - ), - } - - if logFile != "" { - if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { - return nil, fmt.Errorf("failed to create log directory: %w", err) - } - logFileHandle, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return nil, fmt.Errorf("failed to open log file: %w", err) - } - cores = append(cores, zapcore.NewCore( - zapcore.NewJSONEncoder(encoderConfig), - zapcore.AddSync(logFileHandle), - level, - )) - } - - return zap.New(zapcore.NewTee(cores...)), nil -} - -func flush(logger *zap.Logger) { - // per guidance from: https://github.com/uber-go/zap/issues/328 - _ = logger.Sync() -} diff --git a/client/internal/bootstrap/auth.go b/client/internal/bootstrap/auth.go index 46ca1a6..4fb675f 100644 --- a/client/internal/bootstrap/auth.go +++ b/client/internal/bootstrap/auth.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry" "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure" @@ -41,10 +42,10 @@ func extractAccessToken(token *adal.ServicePrincipalToken) (string, error) { // getAccessToken retrieves an AAD access token (JWT) using the specified custom client ID, resource, and cloud provider config. // MSI access tokens are retrieved from IMDS, while service principal tokens are retrieved directly from AAD. func (c *Client) getAccessToken(ctx context.Context, customClientID, resource string, cloudProviderConfig *cloud.ProviderConfig) (string, error) { - spanName := "GetAccessToken" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetAccessToken") + defer endSpan() + + logger := log.MustGetLogger(ctx) userAssignedID := cloudProviderConfig.UserAssignedIdentityID if customClientID != "" { @@ -52,7 +53,7 @@ func (c *Client) getAccessToken(ctx context.Context, customClientID, resource st } if userAssignedID != "" { - c.logger.Info("generating MSI access token", zap.String("clientId", userAssignedID)) + logger.Info("generating MSI access token", zap.String("clientId", userAssignedID)) token, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, &adal.ManagedIdentityOptions{ ClientID: userAssignedID, }) @@ -79,7 +80,7 @@ func (c *Client) getAccessToken(ctx context.Context, customClientID, resource st } if !strings.HasPrefix(cloudProviderConfig.ClientSecret, certificateSecretPrefix) { - c.logger.Info("generating SPN access token with username and password", zap.String("clientId", cloudProviderConfig.ClientID)) + logger.Info("generating SPN access token with username and password", zap.String("clientId", cloudProviderConfig.ClientID)) token, err := adal.NewServicePrincipalToken(*oauthConfig, cloudProviderConfig.ClientID, cloudProviderConfig.ClientSecret, resource) if err != nil { return "", fmt.Errorf("generating SPN access token with username and password: %w", err) @@ -87,7 +88,7 @@ func (c *Client) getAccessToken(ctx context.Context, customClientID, resource st return c.extractAccessTokenFunc(token) } - c.logger.Info("client secret contains certificate data, using certificate to generate SPN access token", zap.String("clientId", cloudProviderConfig.ClientID)) + logger.Info("client secret contains certificate data, using certificate to generate SPN access token", zap.String("clientId", cloudProviderConfig.ClientID)) certData, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cloudProviderConfig.ClientSecret, certificateSecretPrefix)) if err != nil { @@ -98,7 +99,7 @@ func (c *Client) getAccessToken(ctx context.Context, customClientID, resource st return "", fmt.Errorf("decoding pfx certificate data in client secret: %w", err) } - c.logger.Info("generating SPN access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID)) + logger.Info("generating SPN access token with certificate", zap.String("clientId", cloudProviderConfig.ClientID)) token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, cloudProviderConfig.ClientID, certificate, privateKey, resource) if err != nil { return "", fmt.Errorf("generating SPN access token with certificate: %w", err) diff --git a/client/internal/bootstrap/auth_test.go b/client/internal/bootstrap/auth_test.go index 9e4fb1a..13409d0 100644 --- a/client/internal/bootstrap/auth_test.go +++ b/client/internal/bootstrap/auth_test.go @@ -9,9 +9,9 @@ import ( "time" "github.com/stretchr/testify/assert" - "go.uber.org/zap" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/cloud" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/testutil" "github.com/Azure/go-autorest/autorest/adal" @@ -193,15 +193,13 @@ func TestGetAccessToken(t *testing.T) { }, } - logger, _ := zap.NewDevelopment() testTenantID := "d87a2c3e-0c0c-42b2-a883-e48cd8723e22" testResource := "resource" for _, c := range cases { t.Run(c.name, func(t *testing.T) { - ctx := telemetry.NewContext() + ctx := telemetry.WithTracing(log.NewTestContext()) client := &Client{ - logger: logger, extractAccessTokenFunc: c.setupExtractAccessTokenFunc(t), } providerCfg := &cloud.ProviderConfig{ diff --git a/client/internal/bootstrap/bootstrap.go b/client/internal/bootstrap/bootstrap.go index 02fa2a7..6a6b4a9 100644 --- a/client/internal/bootstrap/bootstrap.go +++ b/client/internal/bootstrap/bootstrap.go @@ -7,8 +7,6 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" "time" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry" @@ -30,7 +28,7 @@ func Bootstrap(ctx context.Context, client *Client, config *Config) (finalErr er finalErr = retry.Do( func() error { defer func() { - traces.Add(telemetry.MustGetTracer(ctx).GetTrace()) + traces.Add(telemetry.GetTrace(ctx)) }() kubeconfigData, err := client.BootstrapKubeletClientCredential(ctx, config) @@ -66,17 +64,8 @@ func Bootstrap(ctx context.Context, client *Client, config *Config) (finalErr er } func writeKubeconfig(ctx context.Context, config *clientcmdapi.Config, path string) error { - traceName := "WriteKubeconfig" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(traceName) - defer tracer.EndSpan(traceName) - - if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil { - return &BootstrapError{ - errorType: ErrorTypeWriteKubeconfigFailure, - inner: fmt.Errorf("creating parent directories for kubeconfig path: %w", err), - } - } + endSpan := telemetry.StartSpan(ctx, "WriteKubeconfig") + defer endSpan() if err := clientcmd.WriteToFile(*config, path); err != nil { return &BootstrapError{ diff --git a/client/internal/bootstrap/client.go b/client/internal/bootstrap/client.go index 8f8653c..87a2245 100644 --- a/client/internal/bootstrap/client.go +++ b/client/internal/bootstrap/client.go @@ -10,6 +10,7 @@ import ( "github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/kubeconfig" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry" akssecuretlsbootstrapv1 "github.com/Azure/aks-secure-tls-bootstrap/service/pkg/gen/akssecuretlsbootstrap/v1" "go.uber.org/zap" @@ -17,44 +18,44 @@ import ( ) type Client struct { - logger *zap.Logger imdsClient imds.Client kubeconfigValidator kubeconfig.Validator getServiceClientFunc getServiceClientFunc extractAccessTokenFunc extractAccessTokenFunc } -func NewClient(logger *zap.Logger) (*Client, error) { +func NewClient(ctx context.Context) (*Client, error) { return &Client{ - logger: logger, - imdsClient: imds.NewClient(logger), - kubeconfigValidator: kubeconfig.NewValidator(logger), + imdsClient: imds.NewClient(ctx), + kubeconfigValidator: kubeconfig.NewValidator(), getServiceClientFunc: getServiceClient, extractAccessTokenFunc: extractAccessToken, }, nil } func (c *Client) BootstrapKubeletClientCredential(ctx context.Context, cfg *Config) (*clientcmdapi.Config, error) { + logger := log.MustGetLogger(ctx) + err := c.validateKubeconfig(ctx, cfg.KubeconfigPath, cfg.EnsureAuthorizedClient) if err == nil { - c.logger.Info("existing kubeconfig is valid, will skip bootstrapping", zap.String("kubeconfig", cfg.KubeconfigPath)) + logger.Info("existing kubeconfig is valid, will skip bootstrapping", zap.String("kubeconfig", cfg.KubeconfigPath)) return nil, nil } - c.logger.Info("failed to validate existing kubeconfig, will bootstrap a new client credential", zap.String("kubeconfig", cfg.KubeconfigPath), zap.Error(err)) + logger.Info("failed to validate existing kubeconfig, will bootstrap a new client credential", zap.String("kubeconfig", cfg.KubeconfigPath), zap.Error(err)) token, err := c.getAccessToken(ctx, cfg.CustomClientID, cfg.AADResource, &cfg.ProviderConfig) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetAccessTokenFailure, inner: err, } } - c.logger.Info("generated access token for gRPC connection") + logger.Info("generated access token for gRPC connection") serviceClient, closer, err := c.getServiceClient(ctx, token, cfg) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetServiceClientFailure, inner: err, @@ -62,52 +63,52 @@ func (c *Client) BootstrapKubeletClientCredential(ctx context.Context, cfg *Conf } defer func() { if err := closer(); err != nil { - c.logger.Error("failed to close gRPC client connection", zap.Error(err)) + logger.Error("failed to close gRPC client connection", zap.Error(err)) } }() - c.logger.Info("created bootstrap service gRPC client") + logger.Info("created bootstrap service gRPC client") instanceData, err := c.getInstanceData(ctx) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetInstanceDataFailure, inner: err, } } - c.logger.Info("retrieved instance metadata from IMDS", zap.String("resourceId", instanceData.Compute.ResourceID)) + logger.Info("retrieved instance metadata from IMDS", zap.String("resourceId", instanceData.Compute.ResourceID)) nonce, err := c.getNonce(ctx, serviceClient, &akssecuretlsbootstrapv1.GetNonceRequest{ ResourceId: instanceData.Compute.ResourceID, }) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetNonceFailure, inner: err, } } - c.logger.Info("received new nonce from bootstrap server") + logger.Info("received new nonce from bootstrap server") attestedData, err := c.getAttestedData(ctx, nonce) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetAttestedDataFailure, inner: err, } } - c.logger.Info("retrieved instance attested data from IMDS") + logger.Info("retrieved instance attested data from IMDS") csrPEM, keyPEM, err := c.getCSR(ctx) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetCSRFailure, inner: err, } } - c.logger.Info("generated kubelet client CSR and associated private key") + logger.Info("generated kubelet client CSR and associated private key") certPEM, err := c.getCredential(ctx, serviceClient, &akssecuretlsbootstrapv1.GetCredentialRequest{ ResourceId: instanceData.Compute.ResourceID, @@ -116,13 +117,13 @@ func (c *Client) BootstrapKubeletClientCredential(ctx context.Context, cfg *Conf EncodedCsrPem: base64.StdEncoding.EncodeToString(csrPEM), }) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGetCredentialFailure, inner: err, } } - c.logger.Info("received valid kubelet client credential from bootstrap server") + logger.Info("received valid kubelet client credential from bootstrap server") kubeconfigData, err := c.generateKubeconfig(ctx, certPEM, keyPEM, &kubeconfig.Config{ APIServerFQDN: cfg.APIServerFQDN, @@ -130,34 +131,30 @@ func (c *Client) BootstrapKubeletClientCredential(ctx context.Context, cfg *Conf CredFilePath: cfg.CredFilePath, }) if err != nil { - c.logger.Error(err.Error()) + logger.Error(err.Error()) return nil, &BootstrapError{ errorType: ErrorTypeGenerateKubeconfigFailure, inner: err, } } - c.logger.Info("successfully generated new kubeconfig data") + logger.Info("successfully generated new kubeconfig data") return kubeconfigData, nil } func (c *Client) validateKubeconfig(ctx context.Context, kubeconfigPath string, ensureAuthorizedClient bool) error { - spanName := "ValidateKubeconfig" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "ValidateKubeconfig") + defer endSpan() - if err := c.kubeconfigValidator.Validate(kubeconfigPath, ensureAuthorizedClient); err != nil { + if err := c.kubeconfigValidator.Validate(ctx, kubeconfigPath, ensureAuthorizedClient); err != nil { return fmt.Errorf("failed to validate kubeconfig: %w", err) } return nil } func (c *Client) getServiceClient(ctx context.Context, token string, cfg *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, closeFunc, error) { - spanName := "GetServiceClient" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetServiceClient") + defer endSpan() serviceClient, closer, err := c.getServiceClientFunc(token, cfg) if err != nil { @@ -168,10 +165,8 @@ func (c *Client) getServiceClient(ctx context.Context, token string, cfg *Config } func (c *Client) getInstanceData(ctx context.Context) (*imds.VMInstanceData, error) { - spanName := "GetInstanceData" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetInstanceData") + defer endSpan() instanceData, err := c.imdsClient.GetInstanceData(ctx) if err != nil { @@ -181,10 +176,8 @@ func (c *Client) getInstanceData(ctx context.Context) (*imds.VMInstanceData, err } func (c *Client) getAttestedData(ctx context.Context, nonce string) (*imds.VMAttestedData, error) { - spanName := "GetAttestedData" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetAttestedData") + defer endSpan() attestedData, err := c.imdsClient.GetAttestedData(ctx, nonce) if err != nil { @@ -197,23 +190,20 @@ func (c *Client) getNonce( ctx context.Context, serviceClient akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, req *akssecuretlsbootstrapv1.GetNonceRequest) (string, error) { - spanName := "GetNonce" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetNonce") + defer endSpan() nonceResponse, err := serviceClient.GetNonce(ctx, req) if err != nil { + err = withLastGRPCRetryErrorIfDeadlineExceeded(err) return "", fmt.Errorf("failed to retrieve a nonce from bootstrap server: %w", err) } return nonceResponse.GetNonce(), nil } func (c *Client) getCSR(ctx context.Context) ([]byte, []byte, error) { - spanName := "GetCSR" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetCSR") + defer endSpan() csrPEM, keyPEM, err := makeKubeletClientCSR() if err != nil { @@ -226,16 +216,14 @@ func (c *Client) getCredential( ctx context.Context, serviceClient akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, req *akssecuretlsbootstrapv1.GetCredentialRequest) ([]byte, error) { - spanName := "GetCredential" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GetCredential") + defer endSpan() credentialResponse, err := serviceClient.GetCredential(ctx, req) if err != nil { + err = withLastGRPCRetryErrorIfDeadlineExceeded(err) return nil, fmt.Errorf("failed to retrieve new kubelet client credential from bootstrap server: %w", err) } - c.logger.Info("received credential response from bootstrap server") encodedCertPEM := credentialResponse.GetEncodedCertPem() if encodedCertPEM == "" { @@ -249,10 +237,8 @@ func (c *Client) getCredential( } func (c *Client) generateKubeconfig(ctx context.Context, certPEM, keyPEM []byte, cfg *kubeconfig.Config) (*clientcmdapi.Config, error) { - spanName := "GenerateKubeconfig" - tracer := telemetry.MustGetTracer(ctx) - tracer.StartSpan(spanName) - defer tracer.EndSpan(spanName) + endSpan := telemetry.StartSpan(ctx, "GenerateKubeconfig") + defer endSpan() kubeconfigData, err := kubeconfig.GenerateForCertAndKey(certPEM, keyPEM, cfg) if err != nil { diff --git a/client/internal/bootstrap/client_test.go b/client/internal/bootstrap/client_test.go index 06ab88b..607c62d 100644 --- a/client/internal/bootstrap/client_test.go +++ b/client/internal/bootstrap/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds" imdsmocks "github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds/mocks" kubeconfigmocks "github.com/Azure/aks-secure-tls-bootstrap/client/internal/kubeconfig/mocks" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/telemetry" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/testutil" akssecuretlsbootstrapv1 "github.com/Azure/aks-secure-tls-bootstrap/service/pkg/gen/akssecuretlsbootstrap/v1" @@ -27,7 +28,6 @@ import ( "github.com/Azure/go-autorest/autorest/azure" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" - "go.uber.org/zap" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" ) @@ -43,7 +43,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when specified kubeconfig is already valid", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false).Return(nil).Times(1) + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false).Return(nil).Times(1) serviceClient.EXPECT().GetCredential(gomock.Any(), gomock.Any()).Times(0) serviceClient.EXPECT().GetNonce(gomock.Any(), gomock.Any()).Times(0) imdsClient.EXPECT().GetAttestedData(gomock.Any(), gomock.Any()).Times(0) @@ -59,7 +59,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { bootstrapConfig.ProviderConfig.ClientSecret = "" // force access token failure - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) }, expectedError: &BootstrapError{ @@ -74,7 +74,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when unable to retrieve instance data from IMDS", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(nil, errors.New("cannot get VM instance data from IMDS")).Times(1) @@ -91,7 +91,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when unable to retrieve nonce from bootstrap server", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(&imds.VMInstanceData{}, nil).Times(1) @@ -110,7 +110,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when unable to retrieve attested data from IMDS", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(&imds.VMInstanceData{}, nil).Times(1) @@ -131,7 +131,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when unable to retrieve a credential from the bootstrap server", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(&imds.VMInstanceData{}, nil).Times(1) @@ -154,7 +154,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when bootstrap server responds with an empty credential", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(&imds.VMInstanceData{}, nil).Times(1) @@ -177,7 +177,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { name: "when bootstrap server responds with an invalid credential", setupMocks: func(ctx context.Context, bootstrapConfig *Config, imdsClient *imdsmocks.MockClient, kubeconfigValidator *kubeconfigmocks.MockValidator, serviceClient *akssecuretlsbootstrapv1_mocks.MockSecureTLSBootstrapServiceClient, _ *[]byte) { - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(&imds.VMInstanceData{}, nil).Times(1) @@ -209,7 +209,7 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { }) assert.NoError(t, err) clientCertBlock, _ := pem.Decode(clientCertPEM) - kubeconfigValidator.EXPECT().Validate(bootstrapConfig.KubeconfigPath, false). + kubeconfigValidator.EXPECT().Validate(ctx, bootstrapConfig.KubeconfigPath, false). Return(fmt.Errorf("invalid kubeconfig")).Times(1) imdsClient.EXPECT().GetInstanceData(ctx). Return(&imds.VMInstanceData{}, nil).Times(1) @@ -254,19 +254,18 @@ func TestBootstrapKubeletClientCredential(t *testing.T) { }, } - logger, _ := zap.NewDevelopment() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - ctx := telemetry.NewContext() + ctx := telemetry.WithTracing(log.NewTestContext()) + imdsClient := imdsmocks.NewMockClient(mockCtrl) kubeconfigValidator := kubeconfigmocks.NewMockValidator(mockCtrl) serviceClient := akssecuretlsbootstrapv1_mocks.NewMockSecureTLSBootstrapServiceClient(mockCtrl) client := &Client{ - logger: logger, imdsClient: imdsClient, kubeconfigValidator: kubeconfigValidator, getServiceClientFunc: func(_ string, _ *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, closeFunc, error) { diff --git a/client/internal/bootstrap/grpc.go b/client/internal/bootstrap/grpc.go index 066b4da..7b6b7bd 100644 --- a/client/internal/bootstrap/grpc.go +++ b/client/internal/bootstrap/grpc.go @@ -4,6 +4,7 @@ package bootstrap import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -11,15 +12,22 @@ import ( "time" internalhttp "github.com/Azure/aks-secure-tls-bootstrap/client/internal/http" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" akssecuretlsbootstrapv1 "github.com/Azure/aks-secure-tls-bootstrap/service/pkg/gen/akssecuretlsbootstrap/v1" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" + "go.uber.org/zap" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/oauth" + "google.golang.org/grpc/status" ) +// used to store any errors encountered by the gRPC client when making RPCs to the remote +// within the retry loop configured by retry.UnaryClientInterceptor +var lastGRPCRetryError error + // closeFunc closes a gRPC client connection, fake implementations given in unit tests. type closeFunc func() error @@ -39,7 +47,7 @@ func getServiceClient(token string, cfg *Config) (akssecuretlsbootstrapv1.Secure conn, err := grpc.NewClient( fmt.Sprintf("%s:443", cfg.APIServerFQDN), - grpc.WithUserAgent(internalhttp.GetUserAgentValue()), + grpc.WithUserAgent(internalhttp.UserAgent()), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithPerRPCCredentials(oauth.TokenSource{ TokenSource: oauth2.StaticTokenSource(&oauth2.Token{ @@ -47,6 +55,7 @@ func getServiceClient(token string, cfg *Config) (akssecuretlsbootstrapv1.Secure }), }), grpc.WithUnaryInterceptor(retry.UnaryClientInterceptor( + retry.WithOnRetryCallback(getGRPCOnRetryCallbackFunc()), retry.WithBackoff(retry.BackoffExponentialWithJitterBounded(100*time.Millisecond, 0.75, 2*time.Second)), retry.WithCodes(codes.Aborted, codes.Unavailable), retry.WithMax(30), @@ -59,6 +68,26 @@ func getServiceClient(token string, cfg *Config) (akssecuretlsbootstrapv1.Secure return akssecuretlsbootstrapv1.NewSecureTLSBootstrapServiceClient(conn), conn.Close, nil } +func getGRPCOnRetryCallbackFunc() retry.OnRetryCallback { + // this function is called after every retry attempt assuming the attempt failed, + // and the failure was not caused by a context error (e.g. DeadlineExceeded or Cancelled), + // see: https://github.com/grpc-ecosystem/go-grpc-middleware/blob/main/interceptors/retry/retry.go. + // the error is logged and stored within lastGRPCRetryError. + return func(ctx context.Context, attempt uint, err error) { + log.MustGetLogger(ctx).Error("gRPC request failed", zap.Error(err), zap.Uint("attempt", attempt)) + lastGRPCRetryError = err + } +} + +// withLastGRPCRetryErrorIfDeadlineExceeded wraps the error with lastGRPCRetryError if the error +// was caused by a context.DeadlineExceeded. +func withLastGRPCRetryErrorIfDeadlineExceeded(err error) error { + if lastGRPCRetryError == nil || status.Code(err) != codes.DeadlineExceeded { + return err + } + return fmt.Errorf("%w: last error: %s", err, lastGRPCRetryError) +} + func getTLSConfig(caPEM []byte, nextProto string, insecureSkipVerify bool) (*tls.Config, error) { roots := x509.NewCertPool() if ok := roots.AppendCertsFromPEM(caPEM); !ok { diff --git a/client/internal/bootstrap/grpc_test.go b/client/internal/bootstrap/grpc_test.go index db173fd..55ad624 100644 --- a/client/internal/bootstrap/grpc_test.go +++ b/client/internal/bootstrap/grpc_test.go @@ -5,13 +5,18 @@ package bootstrap import ( "crypto/x509" + "errors" + "fmt" "os" "path/filepath" "testing" "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/testutil" "github.com/stretchr/testify/assert" ) @@ -157,3 +162,56 @@ func TestGetTLSConfig(t *testing.T) { }) } } + +func TestGetGRPCOnRetryCallbackFunc(t *testing.T) { + t.Cleanup(func() { + lastGRPCRetryError = nil + }) + ctx := log.NewTestContext() + errs := []error{errors.New("e0"), errors.New("e1"), errors.New("e2")} + + fn := getGRPCOnRetryCallbackFunc() + for idx, err := range errs { + fn(ctx, uint(idx+1), err) + } + assert.Equal(t, errs[len(errs)-1], lastGRPCRetryError) +} + +func TestWithLastGRPCRetryErrorIfDeadlineExceeded(t *testing.T) { + cases := []struct { + name string + err error + lastGRPCRetryError error + expectedErr error + }{ + { + name: "last GRPC retry error is nil", + err: errors.New("non-retryable error"), + lastGRPCRetryError: nil, + expectedErr: errors.New("non-retryable error"), + }, + { + name: "err is not a context.DeadlineExceeded", + err: errors.New("an error"), + lastGRPCRetryError: errors.New("service unavailable"), + expectedErr: errors.New("an error"), + }, + { + name: "err is a context.DeadlineExceeded and last GRPC retry error is non-nil", + err: status.Error(codes.DeadlineExceeded, "context deadline exceeded"), + lastGRPCRetryError: errors.New("service unavailable"), + expectedErr: fmt.Errorf("%w: last error: %s", status.Error(codes.DeadlineExceeded, "context deadline exceeded"), errors.New("service unavailable")), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + t.Cleanup(func() { + lastGRPCRetryError = nil + }) + lastGRPCRetryError = c.lastGRPCRetryError + + assert.Equal(t, c.expectedErr, withLastGRPCRetryErrorIfDeadlineExceeded(c.err)) + }) + } +} diff --git a/client/internal/http/http.go b/client/internal/http/http.go index d3836ce..efc8dc1 100644 --- a/client/internal/http/http.go +++ b/client/internal/http/http.go @@ -4,43 +4,58 @@ package http import ( + "context" "fmt" "net/http" "time" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/build" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/hashicorp/go-retryablehttp" - "go.uber.org/zap" ) const ( userAgentHeaderKey = "User-Agent" ) -// GetUserAgentValue returns the common User-Agent header value used in all RPCs and HTTP calls. -func GetUserAgentValue() string { +// UserAgent returns the common User-Agent header value used in all RPCs and HTTP calls. +func UserAgent() string { return fmt.Sprintf("aks-secure-tls-bootstrap-client/%s", build.GetVersion()) } -// NewClient returns an http.Client shimed into a *retryablehttp.Client with a custom transport. -func NewClient(logger *zap.Logger) *http.Client { - return NewRetryableClient(logger).StandardClient() +// NewClient returns an http.Client shimmed into a *retryablehttp.Client with a custom transport. +func NewClient(ctx context.Context) *http.Client { + return NewRetryableClient(ctx).StandardClient() } // NewRetryableClient returns a *retryablehttp.Client with a custom transport. -func NewRetryableClient(logger *zap.Logger) *retryablehttp.Client { - c := retryablehttp.NewClient() - c.RetryMax = 5 - c.RetryWaitMin = 300 * time.Millisecond - c.RetryWaitMax = 3 * time.Second - transport := c.HTTPClient.Transport - c.HTTPClient.Transport = &customTransport{ +func NewRetryableClient(ctx context.Context) *retryablehttp.Client { + client := retryablehttp.NewClient() + configureLogger(ctx, client) + configureRetryPolicy(client) + configureTransport(client) + return client +} + +func configureLogger(ctx context.Context, client *retryablehttp.Client) { + client.Logger = log.NewLeveledLoggerShim(log.MustGetLogger(ctx)) +} + +func configureRetryPolicy(client *retryablehttp.Client) { + // retryablehttp.DefaultBackoff implements an exponential backoff strategy + // bounded by RetryWaitMin + RetryWaitMax. It will also attempt to parse out and respect any + // Retry-After header from the server's response. + client.Backoff = retryablehttp.DefaultBackoff + client.RetryMax = 5 + client.RetryWaitMin = 300 * time.Millisecond + client.RetryWaitMax = 3 * time.Second +} + +func configureTransport(client *retryablehttp.Client) { + transport := client.HTTPClient.Transport + client.HTTPClient.Transport = &customTransport{ base: transport, } - c.Logger = &leveledLoggerShim{ - logger: logger, - } - return c } type customTransport struct { @@ -48,6 +63,6 @@ type customTransport struct { } func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set(userAgentHeaderKey, GetUserAgentValue()) + req.Header.Set(userAgentHeaderKey, UserAgent()) return t.base.RoundTrip(req) } diff --git a/client/internal/imds/imds.go b/client/internal/imds/imds.go index 822ba73..b27ec44 100644 --- a/client/internal/imds/imds.go +++ b/client/internal/imds/imds.go @@ -11,6 +11,7 @@ import ( "net/http" internalhttp "github.com/Azure/aks-secure-tls-bootstrap/client/internal/http" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "go.uber.org/zap" ) @@ -24,22 +25,20 @@ type Client interface { type client struct { baseURL string httpClient *http.Client - logger *zap.Logger } var _ Client = (*client)(nil) -func NewClient(logger *zap.Logger) Client { +func NewClient(ctx context.Context) Client { return &client{ baseURL: imdsURL, - httpClient: internalhttp.NewClient(logger), - logger: logger, + httpClient: internalhttp.NewClient(ctx), } } func (c *client) GetInstanceData(ctx context.Context) (*VMInstanceData, error) { url := fmt.Sprintf("%s/%s", c.baseURL, instanceDataEndpoint) - c.logger.Info("calling IMDS instance data endpoint", zap.String("url", url)) + log.MustGetLogger(ctx).Info("calling IMDS instance data endpoint", zap.String("url", url)) params := getCommonParameters() @@ -53,7 +52,7 @@ func (c *client) GetInstanceData(ctx context.Context) (*VMInstanceData, error) { func (c *client) GetAttestedData(ctx context.Context, nonce string) (*VMAttestedData, error) { url := fmt.Sprintf("%s/%s", c.baseURL, attestedDataEndpoint) - c.logger.Info("calling IMDS attested data endpoint", zap.String("url", url)) + log.MustGetLogger(ctx).Info("calling IMDS attested data endpoint", zap.String("url", url)) params := getCommonParameters() params[nonceParameterKey] = nonce diff --git a/client/internal/imds/imds_test.go b/client/internal/imds/imds_test.go index 05efae1..77ef7fb 100644 --- a/client/internal/imds/imds_test.go +++ b/client/internal/imds/imds_test.go @@ -4,7 +4,6 @@ package imds import ( - "context" "fmt" "net/http" "net/http/httptest" @@ -12,8 +11,8 @@ import ( "testing" internalhttp "github.com/Azure/aks-secure-tls-bootstrap/client/internal/http" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/stretchr/testify/assert" - "go.uber.org/zap" ) func TestCallIMDS(t *testing.T) { @@ -60,12 +59,11 @@ func TestCallIMDS(t *testing.T) { }, } - logger, _ := zap.NewDevelopment() + ctx := log.NewTestContext() + for _, tt := range tests { - ctx := context.Background() imdsClient := &client{ - httpClient: internalhttp.NewClient(logger), - logger: logger, + httpClient: internalhttp.NewClient(ctx), } imds := tt.setupTestServer(tt.params) defer imds.Close() @@ -101,8 +99,7 @@ func TestGetInstanceData(t *testing.T) { }, } - logger, _ := zap.NewDevelopment() - ctx := context.Background() + ctx := log.NewTestContext() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -115,8 +112,7 @@ func TestGetInstanceData(t *testing.T) { defer imds.Close() imdsClient := &client{ - httpClient: internalhttp.NewClient(logger), - logger: logger, + httpClient: internalhttp.NewClient(ctx), baseURL: imds.URL, } @@ -160,8 +156,7 @@ func TestGetAttestedData(t *testing.T) { }, } - logger, _ := zap.NewDevelopment() - ctx := context.Background() + ctx := log.NewTestContext() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -175,8 +170,7 @@ func TestGetAttestedData(t *testing.T) { defer imds.Close() imdsClient := &client{ - httpClient: internalhttp.NewClient(logger), - logger: logger, + httpClient: internalhttp.NewClient(ctx), baseURL: imds.URL, } diff --git a/client/internal/kubeconfig/kubeconfig.go b/client/internal/kubeconfig/kubeconfig.go index d20585e..9d96cde 100644 --- a/client/internal/kubeconfig/kubeconfig.go +++ b/client/internal/kubeconfig/kubeconfig.go @@ -30,7 +30,7 @@ func GenerateForCertAndKey(certPEM, keyPEM []byte, cfg *Config) (*clientcmdapi.C if _, err := credBytes.Write(keyPEM); err != nil { return nil, fmt.Errorf("writing client key PEM bytes to buffer: %w", err) } - if err := os.MkdirAll(filepath.Dir(cfg.CredFilePath), 0600); err != nil { + if err := os.MkdirAll(filepath.Dir(cfg.CredFilePath), 0755); err != nil { return nil, fmt.Errorf("creating parent directories for cred file path: %w", err) } if err := os.WriteFile(cfg.CredFilePath, credBytes.Bytes(), 0600); err != nil { diff --git a/client/internal/kubeconfig/mocks/mock_validator.go b/client/internal/kubeconfig/mocks/mock_validator.go index 5c5e939..b77cf62 100644 --- a/client/internal/kubeconfig/mocks/mock_validator.go +++ b/client/internal/kubeconfig/mocks/mock_validator.go @@ -13,6 +13,7 @@ package mocks import ( + context "context" reflect "reflect" gomock "go.uber.org/mock/gomock" @@ -43,15 +44,15 @@ func (m *MockValidator) EXPECT() *MockValidatorMockRecorder { } // Validate mocks base method. -func (m *MockValidator) Validate(kubeconfigPath string, ensureAuthorizedClient bool) error { +func (m *MockValidator) Validate(ctx context.Context, kubeconfigPath string, ensureAuthorizedClient bool) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Validate", kubeconfigPath, ensureAuthorizedClient) + ret := m.ctrl.Call(m, "Validate", ctx, kubeconfigPath, ensureAuthorizedClient) ret0, _ := ret[0].(error) return ret0 } // Validate indicates an expected call of Validate. -func (mr *MockValidatorMockRecorder) Validate(kubeconfigPath, ensureAuthorizedClient any) *gomock.Call { +func (mr *MockValidatorMockRecorder) Validate(ctx, kubeconfigPath, ensureAuthorizedClient any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockValidator)(nil).Validate), kubeconfigPath, ensureAuthorizedClient) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockValidator)(nil).Validate), ctx, kubeconfigPath, ensureAuthorizedClient) } diff --git a/client/internal/kubeconfig/validator.go b/client/internal/kubeconfig/validator.go index 4892ae0..6ed70a2 100644 --- a/client/internal/kubeconfig/validator.go +++ b/client/internal/kubeconfig/validator.go @@ -4,12 +4,14 @@ package kubeconfig import ( + "context" "fmt" "net/http" "os" "time" internalhttp "github.com/Azure/aks-secure-tls-bootstrap/client/internal/http" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/hashicorp/go-retryablehttp" "go.uber.org/zap" "k8s.io/apimachinery/pkg/api/errors" @@ -31,18 +33,17 @@ type clientsetLoaderFunc func(clientConfig *restclient.Config) (kubernetes.Inter //go:generate ../../bin/mockgen -copyright_file=../../../hack/copyright_header.txt -destination=./mocks/mock_validator.go -package=mocks github.com/Azure/aks-secure-tls-bootstrap/client/internal/kubeconfig Validator type Validator interface { - Validate(kubeconfigPath string, ensureAuthorizedClient bool) error + Validate(ctx context.Context, kubeconfigPath string, ensureAuthorizedClient bool) error } type validator struct { clientConfigLoader clientConfigLoaderFunc clientsetLoader clientsetLoaderFunc - logger *zap.Logger } var _ Validator = (*validator)(nil) -func NewValidator(logger *zap.Logger) Validator { +func NewValidator() Validator { return &validator{ clientConfigLoader: func(kubeconfigPath string) (*restclient.Config, error) { if _, err := os.Stat(kubeconfigPath); err != nil { @@ -65,11 +66,10 @@ func NewValidator(logger *zap.Logger) Validator { clientsetLoader: func(clientConfig *restclient.Config) (kubernetes.Interface, error) { return kubernetes.NewForConfig(clientConfig) }, - logger: logger, } } -func (v *validator) Validate(kubeconfigPath string, ensureAuthorizedClient bool) error { +func (v *validator) Validate(ctx context.Context, kubeconfigPath string, ensureAuthorizedClient bool) error { clientConfig, err := v.clientConfigLoader(kubeconfigPath) if err != nil { return fmt.Errorf("failed to create REST client config from kubeconfig: %w", err) @@ -80,9 +80,9 @@ func (v *validator) Validate(kubeconfigPath string, ensureAuthorizedClient bool) if !ensureAuthorizedClient { return nil } - restclient.AddUserAgent(clientConfig, internalhttp.GetUserAgentValue()) + restclient.AddUserAgent(clientConfig, internalhttp.UserAgent()) clientConfig.Wrap(func(rt http.RoundTripper) http.RoundTripper { - c := internalhttp.NewRetryableClient(v.logger) + c := internalhttp.NewRetryableClient(ctx) c.HTTPClient = &http.Client{Transport: rt} return &retryablehttp.RoundTripper{Client: c} }) @@ -93,7 +93,7 @@ func (v *validator) Validate(kubeconfigPath string, ensureAuthorizedClient bool) if err := ensureAuthorized(clientset); err != nil { return fmt.Errorf("failed to ensure client authorization: %w", err) } - v.logger.Info("ensured existing clientset is authorized", zap.String("kubeconfig", kubeconfigPath)) + log.MustGetLogger(ctx).Info("ensured existing clientset is authorized", zap.String("kubeconfig", kubeconfigPath)) return nil } diff --git a/client/internal/kubeconfig/validator_test.go b/client/internal/kubeconfig/validator_test.go index 91173b2..d416d21 100644 --- a/client/internal/kubeconfig/validator_test.go +++ b/client/internal/kubeconfig/validator_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" + "github.com/Azure/aks-secure-tls-bootstrap/client/internal/log" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/testutil" "github.com/stretchr/testify/assert" - "go.uber.org/zap" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" fakediscovery "k8s.io/client-go/discovery/fake" @@ -21,9 +21,7 @@ import ( ) func TestNewValidator(t *testing.T) { - logger, _ := zap.NewDevelopment() - - v := NewValidator(logger) + v := NewValidator() assert.NotNil(t, v) vv, ok := v.(*validator) @@ -33,8 +31,6 @@ func TestNewValidator(t *testing.T) { } func TestValidateKubeconfig(t *testing.T) { - logger, _ := zap.NewDevelopment() - validCertPEM, validKeyPEM, err := testutil.GenerateCertPEM(testutil.CertTemplate{ CommonName: "cn", Organization: "org", @@ -134,11 +130,11 @@ func TestValidateKubeconfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ - logger: logger, - } + ctx := log.NewTestContext() + v := new(validator) tt.setupFunc(v) - err := v.Validate("path", false) + + err := v.Validate(ctx, "path", false) if len(tt.expectedErrs) > 0 { assert.Error(t, err) for _, expectedErr := range tt.expectedErrs { @@ -209,13 +205,10 @@ func TestEnsureAuthorizedClient(t *testing.T) { }) assert.NoError(t, err) - logger, _ := zap.NewDevelopment() - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - v := &validator{ - logger: logger, - } + ctx := log.NewTestContext() + v := new(validator) clientset := fake.NewSimpleClientset() v.clientConfigLoader = func(kubeconfigPath string) (*restclient.Config, error) { @@ -233,7 +226,7 @@ func TestEnsureAuthorizedClient(t *testing.T) { } tt.setupFunc(v, clientset) - err := v.Validate("path", true) + err := v.Validate(ctx, "path", true) if len(tt.expectedErrs) > 0 { assert.Error(t, err) diff --git a/client/internal/log/log.go b/client/internal/log/log.go new file mode 100644 index 0000000..89a50f6 --- /dev/null +++ b/client/internal/log/log.go @@ -0,0 +1,75 @@ +package log + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type contextKey struct{} + +type flushFunc func() + +func NewProductionLogger(logFile string, verbose bool) (*zap.Logger, flushFunc, error) { + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "timestamp" + encoderConfig.EncodeTime = zapcore.RFC3339NanoTimeEncoder + + level := zap.InfoLevel + if verbose { + level = zap.DebugLevel + } + + cores := []zapcore.Core{ + zapcore.NewCore( + zapcore.NewConsoleEncoder(encoderConfig), + zapcore.AddSync(os.Stdout), + level, + ), + } + + if logFile != "" { + if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { + return nil, nil, fmt.Errorf("failed to create log directory: %w", err) + } + logFileHandle, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, nil, fmt.Errorf("failed to open log file: %w", err) + } + cores = append(cores, zapcore.NewCore( + zapcore.NewJSONEncoder(encoderConfig), + zapcore.AddSync(logFileHandle), + level, + )) + } + + logger := zap.New(zapcore.NewTee(cores...)) + + flush := func() { + // per guidance from: https://github.com/uber-go/zap/issues/328 + _ = logger.Sync() + } + + return logger, flush, nil +} + +func WithLogger(ctx context.Context, logger *zap.Logger) context.Context { + return context.WithValue(ctx, contextKey{}, logger) +} + +func MustGetLogger(ctx context.Context) *zap.Logger { + logger, ok := ctx.Value(contextKey{}).(*zap.Logger) + if !ok { + panic("logger not found on context") + } + return logger +} + +func NewTestContext() context.Context { + logger, _ := zap.NewDevelopment() + return WithLogger(context.Background(), logger) +} diff --git a/client/internal/http/log.go b/client/internal/log/shims.go similarity index 56% rename from client/internal/http/log.go rename to client/internal/log/shims.go index 534e2c6..b33c020 100644 --- a/client/internal/http/log.go +++ b/client/internal/log/shims.go @@ -1,7 +1,4 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package http +package log import ( "github.com/hashicorp/go-retryablehttp" @@ -9,30 +6,36 @@ import ( "go.uber.org/zap/zapcore" ) -var _ retryablehttp.LeveledLogger = (*leveledLoggerShim)(nil) +var _ retryablehttp.LeveledLogger = (*LeveledLoggerShim)(nil) -// leveledLoggerShim provides an implementation of retryablehttp.LeveledLogger, shimming into a zap.Logger. -type leveledLoggerShim struct { +// LeveledLoggerShim provides an implementation of retryablehttp.LeveledLogger, shimming into a zap.Logger. +type LeveledLoggerShim struct { logger *zap.Logger } -func (l *leveledLoggerShim) Debug(msg string, keysAndValues ...interface{}) { +func NewLeveledLoggerShim(logger *zap.Logger) *LeveledLoggerShim { + return &LeveledLoggerShim{ + logger: logger, + } +} + +func (l *LeveledLoggerShim) Debug(msg string, keysAndValues ...any) { l.logger.Debug(msg, getZapFields(keysAndValues)...) } -func (l *leveledLoggerShim) Error(msg string, keysAndValues ...interface{}) { +func (l *LeveledLoggerShim) Error(msg string, keysAndValues ...any) { l.logger.Error(msg, getZapFields(keysAndValues)...) } -func (l *leveledLoggerShim) Info(msg string, keysAndValues ...interface{}) { +func (l *LeveledLoggerShim) Info(msg string, keysAndValues ...any) { l.logger.Info(msg, getZapFields(keysAndValues)...) } -func (l *leveledLoggerShim) Warn(msg string, keysAndValues ...interface{}) { +func (l *LeveledLoggerShim) Warn(msg string, keysAndValues ...any) { l.logger.Warn(msg, getZapFields(keysAndValues)...) } -func getZapFields(keysAndValues []interface{}) []zap.Field { +func getZapFields(keysAndValues []any) []zap.Field { var fields []zap.Field failed := len(keysAndValues)%2 != 0 for i := 0; i < len(keysAndValues)-1 && !failed; i += 2 { diff --git a/client/internal/http/log_test.go b/client/internal/log/shims_test.go similarity index 86% rename from client/internal/http/log_test.go rename to client/internal/log/shims_test.go index baee3b6..d028869 100644 --- a/client/internal/http/log_test.go +++ b/client/internal/log/shims_test.go @@ -1,7 +1,4 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package http +package log import ( "bufio" @@ -25,16 +22,16 @@ func (cw customWriter) Sync() error { return nil } -func TestHttpLog(t *testing.T) { +func TestLeveledLoggerShim(t *testing.T) { tests := []struct { name string - logFunc func(*leveledLoggerShim) + logFunc func(*LeveledLoggerShim) expectedStdoutSubstrs []string notExpectedStdoutSubstrs []string }{ { name: "should correctly shim into a zap.Logger", - logFunc: func(shim *leveledLoggerShim) { + logFunc: func(shim *LeveledLoggerShim) { shim.Info("info", "field", "value") shim.Warn("warn", "field", "value") shim.Error("error", "field", "value") @@ -44,7 +41,7 @@ func TestHttpLog(t *testing.T) { }, { name: "unexpected number of keys and values are specified", - logFunc: func(shim *leveledLoggerShim) { + logFunc: func(shim *LeveledLoggerShim) { shim.Info("info", "field", "value", "otherField") shim.Warn("warn", "field", "value", "otherField") shim.Error("error", "field", "value", "otherField") @@ -76,7 +73,7 @@ func TestHttpLog(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - shim := &leveledLoggerShim{ + shim := &LeveledLoggerShim{ logger: logger, } tt.logFunc(shim) diff --git a/client/internal/telemetry/telemetry.go b/client/internal/telemetry/telemetry.go index df961fe..2591a20 100644 --- a/client/internal/telemetry/telemetry.go +++ b/client/internal/telemetry/telemetry.go @@ -12,36 +12,34 @@ import ( "time" ) -// Tracer provides methods to start and stop "spans" for measuring execution times. -type Tracer interface { - // StartSpan creates and starts a new span with the given name. - StartSpan(spanName string) - // EndSpan stops the span with the given name, if it exists. - EndSpan(spanName string) - // GetTrace returns the duration of all spans recorded since the tracer was originally created, - // OR the last time GetTrace was called. After calling this method, all span data is cleared. - GetTrace() Trace +// span represents a unit of work measured by a start and end time. +type span struct { + start, end time.Time } +// spanEnder is a function which ends a particular span on a Tracer. +type spanEnder func() + +// tracerContextKey is used to store tracers on context objects. +type tracerContextKey struct{} + type tracer struct { spans map[string]*span } -var _ Tracer = (*tracer)(nil) - -func NewTracer() Tracer { +func newTracer() *tracer { return &tracer{ spans: make(map[string]*span), } } -func (r *tracer) StartSpan(spanName string) { +func (r *tracer) startSpan(spanName string) { r.spans[spanName] = &span{ start: time.Now(), } } -func (r *tracer) EndSpan(spanName string) { +func (r *tracer) endSpan(spanName string) { endTime := time.Now() if _, ok := r.spans[spanName]; !ok { return @@ -49,7 +47,9 @@ func (r *tracer) EndSpan(spanName string) { r.spans[spanName].end = endTime } -func (r *tracer) GetTrace() Trace { +// GetTrace returns the duration of all spans recorded since the tracer was originally created, +// OR the last time GetTrace was called. After calling this method, all span data is cleared. +func (r *tracer) getTrace() Trace { trace := make(Trace, len(r.spans)) for spanName, span := range r.spans { trace[spanName] = span.end.Sub(span.start) @@ -58,6 +58,35 @@ func (r *tracer) GetTrace() Trace { return trace } +// WithTracing returns a child context with tracing capabilities. +func WithTracing(ctx context.Context) context.Context { + return context.WithValue(ctx, tracerContextKey{}, newTracer()) +} + +// StartSpan starts a span with the given name. This function panics if the context +// or any of its parents wasn't created by WithTracing. +func StartSpan(ctx context.Context, spanName string) spanEnder { + tracer := mustGetTracer(ctx) + tracer.startSpan(spanName) + return func() { + tracer.endSpan(spanName) + } +} + +// GetTrace returns the currently-stored trace on the context. This function panics if +// the context or any of its parents wasn't created by WithTracing. +func GetTrace(ctx context.Context) Trace { + return mustGetTracer(ctx).getTrace() +} + +func mustGetTracer(ctx context.Context) *tracer { + tracer, ok := ctx.Value(tracerContextKey{}).(*tracer) + if !ok { + panic("Tracer is missing from context") + } + return tracer +} + // TraceStore stores a collection of traces in-memory. type TraceStore struct { traces []Trace @@ -94,23 +123,3 @@ func (t *TraceStore) GetTraceSummary() Trace { } return total } - -// NewContext returns a context with a newly initialized Tracer. -func NewContext() context.Context { - return context.WithValue(context.Background(), tracerContextKey{}, NewTracer()) -} - -// WithTracer returns a child context with a new Tracer attached. -func WithTracer(ctx context.Context, tracer Tracer) context.Context { - return context.WithValue(ctx, tracerContextKey{}, tracer) -} - -// MustGetTracer retrieves the Tracer from the specified context. -// If a Tracer is not found on the context, it panics. -func MustGetTracer(ctx context.Context) Tracer { - tracer, ok := ctx.Value(tracerContextKey{}).(Tracer) - if !ok { - panic("Tracer is missing from context") - } - return tracer -} diff --git a/client/internal/telemetry/telemetry_test.go b/client/internal/telemetry/telemetry_test.go index ad46fb4..ac92dca 100644 --- a/client/internal/telemetry/telemetry_test.go +++ b/client/internal/telemetry/telemetry_test.go @@ -14,37 +14,34 @@ import ( ) func TestTracer(t *testing.T) { - newTracer := NewTracer() - assert.NotNil(t, newTracer) - - tracer, ok := (newTracer).(*tracer) - assert.True(t, ok) + tracer := newTracer() + assert.NotNil(t, tracer) spanName := "TestSpan" - tracer.StartSpan(spanName) + tracer.startSpan(spanName) span, ok := tracer.spans[spanName] assert.True(t, ok) assert.NotNil(t, span) assert.NotZero(t, span.start) assert.Zero(t, span.end) - tracer.StartSpan(spanName) + tracer.startSpan(spanName) span, ok = tracer.spans[spanName] assert.True(t, ok) assert.NotNil(t, span) assert.NotZero(t, span.start) assert.Zero(t, span.end) - tracer.EndSpan(spanName) + tracer.endSpan(spanName) span, ok = tracer.spans[spanName] assert.True(t, ok) assert.NotNil(t, span) assert.NotZero(t, span.start) assert.NotZero(t, span.end) - tracer.EndSpan("non-existent-span") + tracer.endSpan("non-existent-span") assert.Len(t, tracer.spans, 1) - trace := tracer.GetTrace() + trace := tracer.getTrace() assert.NotNil(t, trace) assert.Empty(t, tracer.spans) assert.Len(t, trace, 1) @@ -60,34 +57,42 @@ func TestTracer(t *testing.T) { assert.Equal(t, traceString, fmt.Sprintf(`{"TestSpanMilliseconds":%d}`, duration.Milliseconds())) } -func TestNewContext(t *testing.T) { - ctx := NewContext() +func TestStartStopSpan(t *testing.T) { + ctx := WithTracing(context.Background()) assert.NotNil(t, ctx) - tracer := MustGetTracer(ctx) - assert.NotNil(t, tracer) -} + spanName := "TestSpan" + endSpan := StartSpan(ctx, spanName) + assert.NotNil(t, endSpan) -func TestWithTracer(t *testing.T) { - ctx := context.Background() - tracer := NewTracer() - assert.NotNil(t, tracer) + endSpan() + + trace := GetTrace(ctx) + assert.NotNil(t, trace) + + spanDuration, ok := trace[spanName] + assert.True(t, ok) + assert.NotZero(t, spanDuration) + + assert.Empty(t, GetTrace(ctx)) +} - ctx = WithTracer(ctx, tracer) +func TestWithTracing(t *testing.T) { + ctx := WithTracing(context.Background()) assert.NotNil(t, ctx) - tracer = MustGetTracer(ctx) + tracer := mustGetTracer(ctx) assert.NotNil(t, tracer) } func TestMustGetTracer(t *testing.T) { - ctx := NewContext() - tracer := MustGetTracer(ctx) + ctx := WithTracing(context.Background()) + tracer := mustGetTracer(ctx) assert.NotNil(t, tracer) ctx = context.Background() assert.Panics(t, func() { - MustGetTracer(ctx) + mustGetTracer(ctx) }) } diff --git a/client/internal/telemetry/types.go b/client/internal/telemetry/types.go index 67376d4..6824dbf 100644 --- a/client/internal/telemetry/types.go +++ b/client/internal/telemetry/types.go @@ -8,12 +8,6 @@ import ( "time" ) -type tracerContextKey struct{} - -type span struct { - start, end time.Time -} - // Trace is a mapping of span names to their corresponding durations as measured by a Tracer. type Trace map[string]time.Duration