diff --git a/client/cmd/client/main.go b/client/cmd/client/main.go index 3e801fe..fedcdb0 100644 --- a/client/cmd/client/main.go +++ b/client/cmd/client/main.go @@ -11,6 +11,7 @@ import ( "os/signal" "path/filepath" "syscall" + "time" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/bootstrap" "go.uber.org/zap" @@ -40,6 +41,7 @@ func init() { flag.StringVar(&bootstrapConfig.KeyFilePath, "key-file", "", "path to the file which will contain the PEM-encoded client key, referenced by the generated kubeconfig.") flag.BoolVar(&bootstrapConfig.InsecureSkipTLSVerify, "insecure-skip-tls-verify", false, "skip TLS verification when connecting to the control plane") flag.BoolVar(&bootstrapConfig.EnsureAuthorizedClient, "ensure-authorized", false, "ensure the specified kubeconfig contains an authorized clientset before bootstrapping") + flag.DurationVar(&bootstrapConfig.Timeout, "deadline", time.Minute, "deadline within which bootstrapping must succeed") flag.Parse() } @@ -73,11 +75,15 @@ func run(ctx context.Context, logger *zap.Logger) int { logger.Error("error constructing bootstrap client", zap.Error(err)) return 1 } - kubeconfigData, err := client.GetKubeletClientCredential(ctx, &bootstrapConfig) + + timeoutCtx, cancel := context.WithTimeout(ctx, bootstrapConfig.Timeout) + defer cancel() + kubeconfigData, err := client.GetKubeletClientCredential(timeoutCtx, &bootstrapConfig) if err != nil { logger.Error("error generating kubelet client credential", zap.Error(err)) return 1 } + if kubeconfigData != nil { if err := clientcmd.WriteToFile(*kubeconfigData, bootstrapConfig.KubeconfigPath); err != nil { logger.Error("error writing generated kubeconfig to disk", zap.Error(err)) diff --git a/client/internal/bootstrap/client.go b/client/internal/bootstrap/client.go index 6db3659..2fb459d 100644 --- a/client/internal/bootstrap/client.go +++ b/client/internal/bootstrap/client.go @@ -33,31 +33,28 @@ func NewClient(logger *zap.Logger) (*Client, error) { }, nil } -func (c *Client) GetKubeletClientCredential(ctx context.Context, cfg *Config) (*clientcmdapi.Config, error) { - err := c.kubeconfigValidator.Validate(cfg.KubeconfigPath, cfg.EnsureAuthorizedClient) +func (c *Client) GetKubeletClientCredential(ctx context.Context, config *Config) (*clientcmdapi.Config, error) { + err := c.kubeconfigValidator.Validate(config.KubeconfigPath, config.EnsureAuthorizedClient) if err == nil { - c.logger.Info("existing kubeconfig is valid, will skip bootstrapping", zap.String("kubeconfig", cfg.KubeconfigPath)) + c.logger.Info("existing kubeconfig is valid, will skip bootstrapping", zap.String("kubeconfig", config.KubeconfigPath)) return nil, nil } - c.logger.Info("failed to validate existing kubeconfig, will bootstrap a new client credential", zap.String("kubeconfig", cfg.KubeconfigPath), zap.Error(err)) + c.logger.Info("failed to validate existing kubeconfig, will bootstrap a new client credential", zap.String("kubeconfig", config.KubeconfigPath), zap.Error(err)) - token, err := c.getAccessToken(cfg.CustomClientID, cfg.AADResource, &cfg.AzureConfig) + token, err := c.getAccessToken(config.CustomClientID, config.AADResource, &config.AzureConfig) if err != nil { c.logger.Error("failed to generate access token for gRPC connection", zap.Error(err)) return nil, fmt.Errorf("failed to generate access token for gRPC connection: %w", err) } c.logger.Info("generated access token for gRPC connection") - serviceClient, close, err := c.getServiceClientFunc(c.logger, token, cfg) + serviceClient, closer, err := c.getServiceClientFunc(token, config) if err != nil { c.logger.Error("failed to setup bootstrap service connection", zap.Error(err)) return nil, fmt.Errorf("failed to setup bootstrap service connection: %w", err) } - defer func() { - if err := close(); err != nil { - c.logger.Error("failed to close gRPC client connection", zap.Error(err)) - } - }() + defer closer.close(c.logger) + c.logger.Info("created gRPC connection and bootstrap service client") instanceData, err := c.imdsClient.GetInstanceData(ctx) @@ -114,10 +111,10 @@ func (c *Client) GetKubeletClientCredential(ctx context.Context, cfg *Config) (* return nil, fmt.Errorf("failed to decode cert data from bootstrap server: %w", err) } kubeconfigData, err := kubeconfig.GenerateForCertAndKey(certPEM, privateKey, &kubeconfig.Config{ - APIServerFQDN: cfg.APIServerFQDN, - ClusterCAFilePath: cfg.ClusterCAFilePath, - CertFilePath: cfg.CertFilePath, - KeyFilePath: cfg.KeyFilePath, + APIServerFQDN: config.APIServerFQDN, + ClusterCAFilePath: config.ClusterCAFilePath, + CertFilePath: config.CertFilePath, + KeyFilePath: config.KeyFilePath, }) if err != nil { c.logger.Error("failed to generate kubeconfig for new client cert and key", zap.Error(err)) diff --git a/client/internal/bootstrap/client_test.go b/client/internal/bootstrap/client_test.go index 29184a8..e1e9634 100644 --- a/client/internal/bootstrap/client_test.go +++ b/client/internal/bootstrap/client_test.go @@ -111,7 +111,7 @@ var _ = Describe("Client tests", Ordered, func() { logger: logger, imdsClient: imdsClient, kubeconfigValidator: kubeconfigValidator, - getServiceClientFunc: func(_ *zap.Logger, _ string, _ *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, func() error, error) { + getServiceClientFunc: func(_ string, _ *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, closerFunc, error) { return serviceClient, func() error { return nil }, nil }, extractAccessTokenFunc: func(token *adal.ServicePrincipalToken) (string, error) { diff --git a/client/internal/bootstrap/config.go b/client/internal/bootstrap/config.go index 92406d7..ca3de70 100644 --- a/client/internal/bootstrap/config.go +++ b/client/internal/bootstrap/config.go @@ -7,23 +7,25 @@ import ( "encoding/json" "fmt" "os" + "time" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/datamodel" ) type Config struct { datamodel.AzureConfig - AzureConfigPath string `json:"azureConfigPath"` - APIServerFQDN string `json:"apiServerFqdn"` - CustomClientID string `json:"customClientId"` - NextProto string `json:"nextProto"` - AADResource string `json:"aadResource"` - ClusterCAFilePath string `json:"clusterCaFilePath"` - KubeconfigPath string `json:"kubeconfigPath"` - CertFilePath string `json:"certFilePath"` - KeyFilePath string `json:"keyFilePath"` - InsecureSkipTLSVerify bool `json:"insecureSkipTlsVerify"` - EnsureAuthorizedClient bool `json:"ensureAuthorizedClient"` + AzureConfigPath string `json:"azureConfigPath"` + APIServerFQDN string `json:"apiServerFqdn"` + CustomClientID string `json:"customClientId"` + NextProto string `json:"nextProto"` + AADResource string `json:"aadResource"` + ClusterCAFilePath string `json:"clusterCaFilePath"` + KubeconfigPath string `json:"kubeconfigPath"` + CertFilePath string `json:"certFilePath"` + KeyFilePath string `json:"keyFilePath"` + InsecureSkipTLSVerify bool `json:"insecureSkipTlsVerify"` + EnsureAuthorizedClient bool `json:"ensureAuthorizedClient"` + Timeout time.Duration `json:"timeout"` } func (c *Config) Validate() error { @@ -51,27 +53,33 @@ func (c *Config) Validate() error { if c.KeyFilePath == "" { return fmt.Errorf("key file path must be specified") } + if c.Timeout == 0 { + return fmt.Errorf("timeout must be specified") + } return c.loadAzureConfig() } func (c *Config) LoadFromFile(path string) error { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("reading config file: %w", err) - } - if err := json.Unmarshal(data, c); err != nil { - return fmt.Errorf("unmarshalling config file content: %w", err) + if err := loadJSON(path, c); err != nil { + return fmt.Errorf("loading bootstrap config file: %w", err) } return nil } func (c *Config) loadAzureConfig() error { - data, err := os.ReadFile(c.AzureConfigPath) + if err := loadJSON(c.AzureConfigPath, &c.AzureConfig); err != nil { + return fmt.Errorf("loading azure config file: %w", err) + } + return nil +} + +func loadJSON(path string, out interface{}) error { + data, err := os.ReadFile(path) if err != nil { - return fmt.Errorf("reading azure config data: %w", err) + return fmt.Errorf("reading file %s path: %w", path, err) } - if err = json.Unmarshal(data, &c.AzureConfig); err != nil { - return fmt.Errorf("unmarshalling azure config data: %w", err) + if err := json.Unmarshal(data, out); err != nil { + return fmt.Errorf("unmarshalling json data from %s: %w", path, err) } return nil } diff --git a/client/internal/bootstrap/config_test.go b/client/internal/bootstrap/config_test.go index 3e9eadc..bcd7ec2 100644 --- a/client/internal/bootstrap/config_test.go +++ b/client/internal/bootstrap/config_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "os" "path/filepath" + "time" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/datamodel" "github.com/Azure/go-autorest/autorest/azure" @@ -30,6 +31,7 @@ var _ = Describe("config tests", func() { KubeconfigPath: "path", CertFilePath: "path", KeyFilePath: "path", + Timeout: time.Minute, } }) @@ -105,12 +107,22 @@ var _ = Describe("config tests", func() { }) }) + When("timeout is not specified", func() { + It("should return an error", func() { + cfg.Timeout = 0 + err := cfg.Validate() + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("timeout must be specified")) + }) + }) + When("azure config path does not exist", func() { It("should return an error", func() { cfg.AzureConfigPath = "does/not/exist.json" err := cfg.Validate() Expect(err).ToNot(BeNil()) - Expect(err.Error()).To(ContainSubstring("reading azure config data")) + Expect(err.Error()).To(ContainSubstring("loading azure config file")) + Expect(err.Error()).To(ContainSubstring("reading file")) }) }) @@ -123,7 +135,8 @@ var _ = Describe("config tests", func() { cfg.AzureConfigPath = path err = cfg.Validate() Expect(err).ToNot(BeNil()) - Expect(err.Error()).To(ContainSubstring("unmarshalling azure config data")) + Expect(err.Error()).To(ContainSubstring("loading azure config file")) + Expect(err.Error()).To(ContainSubstring("unmarshalling json data")) }) }) @@ -162,7 +175,8 @@ var _ = Describe("config tests", func() { path := "does/not/exist.json" err := cfg.LoadFromFile(path) Expect(err).ToNot(BeNil()) - Expect(err.Error()).To(ContainSubstring("reading config file")) + Expect(err.Error()).To(ContainSubstring("loading bootstrap config file")) + Expect(err.Error()).To(ContainSubstring("reading file")) }) }) @@ -174,7 +188,8 @@ var _ = Describe("config tests", func() { Expect(err).To(BeNil()) err = cfg.LoadFromFile(path) Expect(err).ToNot(BeNil()) - Expect(err.Error()).To(ContainSubstring("unmarshalling config file content")) + Expect(err.Error()).To(ContainSubstring("loading bootstrap config file")) + Expect(err.Error()).To(ContainSubstring("unmarshalling json data")) }) }) diff --git a/client/internal/bootstrap/grpc.go b/client/internal/bootstrap/grpc.go index a209568..fa04e0e 100644 --- a/client/internal/bootstrap/grpc.go +++ b/client/internal/bootstrap/grpc.go @@ -18,23 +18,32 @@ import ( "google.golang.org/grpc/credentials/oauth" ) +// closerFunc closes a gRPC connection. +type closerFunc func() error + +func (c closerFunc) close(logger *zap.Logger) { + if err := c(); err != nil { + logger.Error("closing gRPC client connection: %s", zap.Error(err)) + } +} + // getServiceClientFunc returns a new SecureTLSBootstrapServiceClient over a gRPC connection, fake implementations given in unit tests. -type getServiceClientFunc func(logger *zap.Logger, token string, cfg *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, func() error, error) +type getServiceClientFunc func(token string, config *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, closerFunc, error) -func getServiceClient(logger *zap.Logger, token string, cfg *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, func() error, error) { - clusterCAData, err := os.ReadFile(cfg.ClusterCAFilePath) +func getServiceClient(token string, config *Config) (akssecuretlsbootstrapv1.SecureTLSBootstrapServiceClient, closerFunc, error) { + clusterCAData, err := os.ReadFile(config.ClusterCAFilePath) if err != nil { - return nil, nil, fmt.Errorf("reading cluster CA data from %s: %w", cfg.ClusterCAFilePath, err) + return nil, nil, fmt.Errorf("reading cluster CA data from %s: %w", config.ClusterCAFilePath, err) } - logger.Info("read cluster CA data", zap.String("path", cfg.ClusterCAFilePath)) - tlsConfig, err := getTLSConfig(clusterCAData, cfg.NextProto, cfg.InsecureSkipTLSVerify) + tlsConfig, err := getTLSConfig(clusterCAData, config.NextProto, config.InsecureSkipTLSVerify) if err != nil { return nil, nil, fmt.Errorf("failed to get TLS config: %w", err) } conn, err := grpc.NewClient( - fmt.Sprintf("%s:443", cfg.APIServerFQDN), + fmt.Sprintf("%s:443", config.APIServerFQDN), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), grpc.WithUserAgent(internalhttp.GetUserAgentValue()), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithPerRPCCredentials(oauth.TokenSource{ @@ -46,7 +55,6 @@ func getServiceClient(logger *zap.Logger, token string, cfg *Config) (akssecuret if err != nil { return nil, nil, fmt.Errorf("failed to dial client connection with context: %w", err) } - logger.Info("dialed TLS bootstrap server and created GRPC connection") return akssecuretlsbootstrapv1.NewSecureTLSBootstrapServiceClient(conn), conn.Close, nil } diff --git a/client/internal/bootstrap/grpc_test.go b/client/internal/bootstrap/grpc_test.go index b2ece18..7068f61 100644 --- a/client/internal/bootstrap/grpc_test.go +++ b/client/internal/bootstrap/grpc_test.go @@ -11,7 +11,6 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "go.uber.org/zap" "google.golang.org/grpc/test/bufconn" "github.com/Azure/aks-secure-tls-bootstrap/client/internal/testutil" @@ -20,12 +19,9 @@ import ( var _ = Describe("grpc", Ordered, func() { var ( clusterCACertPEM []byte - logger *zap.Logger ) BeforeAll(func() { - logger, _ = zap.NewDevelopment() - var err error clusterCACertPEM, _, err = testutil.GenerateCertPEM(testutil.CertTemplate{ CommonName: "hcp", @@ -39,7 +35,7 @@ var _ = Describe("grpc", Ordered, func() { Context("secureTLSBootstrapServiceClientFactory", func() { When("cluster ca data cannot be read", func() { It("should return an error", func() { - serviceClient, close, err := getServiceClient(logger, "token", &Config{ + serviceClient, close, err := getServiceClient("token", &Config{ ClusterCAFilePath: "does/not/exist.crt", NextProto: "nextProto", APIServerFQDN: "fqdn", @@ -59,7 +55,7 @@ var _ = Describe("grpc", Ordered, func() { err := os.WriteFile(caFilePath, []byte("SGVsbG8gV29ybGQh"), os.ModePerm) Expect(err).To(BeNil()) - serviceClient, close, err := getServiceClient(logger, "token", &Config{ + serviceClient, close, err := getServiceClient("token", &Config{ ClusterCAFilePath: caFilePath, NextProto: "nextProto", APIServerFQDN: "fqdn", @@ -83,7 +79,7 @@ var _ = Describe("grpc", Ordered, func() { lis := bufconn.Listen(1024) defer lis.Close() - serviceClient, close, err := getServiceClient(logger, "token", &Config{ + serviceClient, close, err := getServiceClient("token", &Config{ ClusterCAFilePath: caFilePath, NextProto: "nextProto", APIServerFQDN: lis.Addr().String(),