From 55a739f592e5cd2035a734e906609def1bf9092d Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Wed, 9 Oct 2024 17:30:30 +0400 Subject: [PATCH 1/9] Validate capi-auth-token on dqlite/remove --- cmd/cluster-agent.go | 6 +++++ cmd/consts.go | 5 ++++ cmd/init.go | 6 +++++ pkg/api/v2/consts.go | 5 ++++ pkg/api/v2/register.go | 4 +++- pkg/api/v2/remove.go | 10 ++++++-- pkg/api/v2/remove_test.go | 40 ++++++++++++++++++++++++++++---- pkg/snap/interface.go | 5 ++++ pkg/snap/mock/mock.go | 13 +++++++++++ pkg/snap/snap.go | 16 ++++++++++++- pkg/snap/snap_addons_test.go | 4 ++-- pkg/snap/snap_capi_token_test.go | 37 +++++++++++++++++++++++++++++ pkg/snap/snap_containerd_test.go | 2 +- pkg/snap/snap_files_test.go | 2 +- pkg/snap/snap_images_test.go | 2 +- pkg/snap/snap_join_test.go | 6 ++--- pkg/snap/snap_lock_test.go | 2 +- pkg/snap/snap_service_test.go | 2 +- pkg/snap/snap_sign_test.go | 2 +- pkg/snap/snap_token_test.go | 14 +++++------ pkg/snap/snap_upgrade_test.go | 4 ++-- 21 files changed, 159 insertions(+), 28 deletions(-) create mode 100644 cmd/consts.go create mode 100644 pkg/api/v2/consts.go create mode 100644 pkg/snap/snap_capi_token_test.go diff --git a/cmd/cluster-agent.go b/cmd/cluster-agent.go index 28181f7..b769f63 100644 --- a/cmd/cluster-agent.go +++ b/cmd/cluster-agent.go @@ -36,10 +36,16 @@ var clusterAgentCmd = &cobra.Command{ Long: `The MicroK8s cluster agent is an API server that orchestrates the lifecycle of a MicroK8s cluster.`, Run: func(cmd *cobra.Command, args []string) { + capiPath := os.Getenv("CAPI_PATH") + if capiPath == "" { + capiPath = capiDefaultPath + } + s := snap.NewSnap( os.Getenv("SNAP"), os.Getenv("SNAP_DATA"), os.Getenv("SNAP_COMMON"), + capiPath, snap.WithRetryApplyCNI(20, 3*time.Second), ) diff --git a/cmd/consts.go b/cmd/consts.go new file mode 100644 index 0000000..7c74383 --- /dev/null +++ b/cmd/consts.go @@ -0,0 +1,5 @@ +package cmd + +const ( + capiDefaultPath = "/capi" +) diff --git a/cmd/init.go b/cmd/init.go index 0a0661e..6c16119 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -19,10 +19,16 @@ var ( Short: "Apply MicroK8s configurations", Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { + capiPath := os.Getenv("CAPI_PATH") + if capiPath == "" { + capiPath = capiDefaultPath + } + s := snap.NewSnap( os.Getenv("SNAP"), os.Getenv("SNAP_DATA"), os.Getenv("SNAP_COMMON"), + capiPath, ) l := k8sinit.NewLauncher(s, initPreInit) diff --git a/pkg/api/v2/consts.go b/pkg/api/v2/consts.go new file mode 100644 index 0000000..368e646 --- /dev/null +++ b/pkg/api/v2/consts.go @@ -0,0 +1,5 @@ +package v2 + +const ( + CAPIAuthTokenHeader = "capi-auth-token" +) diff --git a/pkg/api/v2/register.go b/pkg/api/v2/register.go index 4edb6ad..115bd46 100644 --- a/pkg/api/v2/register.go +++ b/pkg/api/v2/register.go @@ -67,7 +67,9 @@ func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.Handl return } - if rc, err := a.RemoveFromDqlite(r.Context(), req); err != nil { + token := r.Header.Get(CAPIAuthTokenHeader) + + if rc, err := a.RemoveFromDqlite(r.Context(), req, token); err != nil { httputil.Error(w, rc, fmt.Errorf("failed to remove from dqlite: %w", err)) return } diff --git a/pkg/api/v2/remove.go b/pkg/api/v2/remove.go index e989afa..54e32cc 100644 --- a/pkg/api/v2/remove.go +++ b/pkg/api/v2/remove.go @@ -11,11 +11,17 @@ import ( // RemoveFromDqliteRequest represents a request to remove a node from the dqlite cluster. type RemoveFromDqliteRequest struct { // RemoveEndpoint is the endpoint of the node to remove from the dqlite cluster. - RemoveEndpoint string `json:"removeEndpoint"` + RemoveEndpoint string `json:"remove_endpoint"` } // RemoveFromDqlite implements the "POST /v2/dqlite/remove" endpoint and removes a node from the dqlite cluster. -func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest) (int, error) { +func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest, token string) (int, error) { + if isValid, err := a.Snap.IsCAPIAuthTokenValid(token); err != nil { + return http.StatusUnauthorized, fmt.Errorf("failed to validate CAPI auth token: %w", err) + } else if !isValid { + return http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token) + } + if err := snaputil.RemoveNodeFromDqlite(ctx, a.Snap, req.RemoveEndpoint); err != nil { return http.StatusInternalServerError, fmt.Errorf("failed to remove node from dqlite: %w", err) } diff --git a/pkg/api/v2/remove_test.go b/pkg/api/v2/remove_test.go index 5ad4aba..b1ff464 100644 --- a/pkg/api/v2/remove_test.go +++ b/pkg/api/v2/remove_test.go @@ -17,23 +17,55 @@ func TestRemove(t *testing.T) { cmdErr := errors.New("failed to run command") apiv2 := &v2.API{ Snap: &mock.Snap{ - RunCommandErr: cmdErr, + RunCommandErr: cmdErr, + CAPIAuthTokenValid: true, }, } - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}) + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") g := NewWithT(t) g.Expect(err).To(MatchError(cmdErr)) g.Expect(rc).To(Equal(http.StatusInternalServerError)) }) + t.Run("InvalidToken", func(t *testing.T) { + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenValid: false, // explicitly set to false + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(HaveOccurred()) + g.Expect(rc).To(Equal(http.StatusUnauthorized)) + }) + + t.Run("TokenFileNotFound", func(t *testing.T) { + tokenErr := errors.New("token file not found") + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenError: tokenErr, + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(MatchError(tokenErr)) + g.Expect(rc).To(Equal(http.StatusUnauthorized)) + }) + t.Run("RemovesSuccessfully", func(t *testing.T) { apiv2 := &v2.API{ - Snap: &mock.Snap{}, + Snap: &mock.Snap{ + CAPIAuthTokenValid: true, + }, } - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}) + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") g := NewWithT(t) g.Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/snap/interface.go b/pkg/snap/interface.go index b783da9..b4e4ed7 100644 --- a/pkg/snap/interface.go +++ b/pkg/snap/interface.go @@ -13,6 +13,8 @@ type Snap interface { GetSnapDataPath(parts ...string) string // GetSnapCommonPath returns the path to a file or directory in the snap's common directory. GetSnapCommonPath(parts ...string) string + // GetCAPIPath returns the path to a file or directory in the CAPI directory. + GetCAPIPath(parts ...string) string // RunCommand runs a shell command. RunCommand(ctx context.Context, commands ...string) error @@ -98,6 +100,9 @@ type Snap interface { // GetKnownToken returns the token for a known user from the known_users.csv file. GetKnownToken(username string) (string, error) + // IsCAPIAuthTokenValid returns true if token is a valid CAPI auth token. + IsCAPIAuthTokenValid(token string) (bool, error) + // SignCertificate signs the certificate signing request, and returns the certificate in PEM format. SignCertificate(ctx context.Context, csrPEM []byte) ([]byte, error) diff --git a/pkg/snap/mock/mock.go b/pkg/snap/mock/mock.go index f719342..8c5e001 100644 --- a/pkg/snap/mock/mock.go +++ b/pkg/snap/mock/mock.go @@ -34,6 +34,7 @@ type Snap struct { SnapDir string SnapDataDir string SnapCommonDir string + CAPIDir string RunCommandCalledWith []RunCommandCall RunCommandErr error @@ -85,6 +86,9 @@ type Snap struct { KubeletTokens map[string]string // map hostname to token KnownTokens map[string]string // map username to token + CAPIAuthTokenValid bool + CAPIAuthTokenError error + SignCertificateCalledWith []string // string(csrPEM) SignedCertificate string @@ -116,6 +120,11 @@ func (s *Snap) GetSnapCommonPath(parts ...string) string { return filepath.Join(append([]string{s.SnapCommonDir}, parts...)...) } +// GetCAPIPath is a mock implementation for the snap.Snap interface. +func (s *Snap) GetCAPIPath(parts ...string) string { + return filepath.Join(append([]string{s.CAPIDir}, parts...)...) +} + // RunCommand is a mock implementation for the snap.Snap interface. func (s *Snap) RunCommand(_ context.Context, commands ...string) error { s.RunCommandCalledWith = append(s.RunCommandCalledWith, RunCommandCall{Commands: commands}) @@ -320,6 +329,10 @@ func (s *Snap) GetKnownToken(username string) (string, error) { return "", fmt.Errorf("no known token for user %s", username) } +func (s *Snap) IsCAPIAuthTokenValid(token string) (bool, error) { + return s.CAPIAuthTokenValid, s.CAPIAuthTokenError +} + // RunUpgrade is a mock implementation for the snap.Snap interface. func (s *Snap) RunUpgrade(ctx context.Context, upgrade string, phase string) error { s.RunUpgradeCalledWith = append(s.RunUpgradeCalledWith, fmt.Sprintf("%s %s", upgrade, phase)) diff --git a/pkg/snap/snap.go b/pkg/snap/snap.go index 2e9f10d..489770c 100644 --- a/pkg/snap/snap.go +++ b/pkg/snap/snap.go @@ -23,6 +23,7 @@ type snap struct { snapDir string snapDataDir string snapCommonDir string + capiPath string runCommand func(context.Context, ...string) error clusterTokensMu sync.Mutex @@ -36,11 +37,12 @@ type snap struct { // NewSnap creates a new interface with the MicroK8s snap. // NewSnap accepts the $SNAP, $SNAP_DATA and $SNAP_COMMON, directories, and a number of options. -func NewSnap(snapDir, snapDataDir, snapCommonDir string, options ...func(s *snap)) Snap { +func NewSnap(snapDir, snapDataDir, snapCommonDir, capiPath string, options ...func(s *snap)) Snap { s := &snap{ snapDir: snapDir, snapDataDir: snapDataDir, snapCommonDir: snapCommonDir, + capiPath: capiPath, runCommand: util.RunCommand, } @@ -65,6 +67,9 @@ func (s *snap) GetSnapDataPath(parts ...string) string { func (s *snap) GetSnapCommonPath(parts ...string) string { return filepath.Join(append([]string{s.snapCommonDir}, parts...)...) } +func (s *snap) GetCAPIPath(parts ...string) string { + return filepath.Join(append([]string{s.capiPath}, parts...)...) +} func (s *snap) GetGroupName() string { if s.isStrict() { @@ -331,6 +336,15 @@ func (s *snap) GetKnownToken(username string) (string, error) { return "", fmt.Errorf("no known token found for user %s", username) } +// IsCAPIAuthTokenValid checks if the given CAPI auth token is valid. +func (s *snap) IsCAPIAuthTokenValid(token string) (bool, error) { + contents, err := util.ReadFile(s.GetCAPIPath("etc", "token")) + if err != nil { + return false, fmt.Errorf("failed to read token file: %w", err) + } + return strings.TrimSpace(contents) == token, nil +} + func (s *snap) SignCertificate(ctx context.Context, csrPEM []byte) ([]byte, error) { // TODO: consider using crypto/x509 for this instead of relying on openssl commands. // NOTE(neoaggelos): x509.CreateCertificate() has some hardcoded fields that are incompatible with MicroK8s. diff --git a/pkg/snap/snap_addons_test.go b/pkg/snap/snap_addons_test.go index f2bd073..b106839 100644 --- a/pkg/snap/snap_addons_test.go +++ b/pkg/snap/snap_addons_test.go @@ -12,7 +12,7 @@ import ( func TestAddons(t *testing.T) { t.Run("EnableDisable", func(t *testing.T) { runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) s.EnableAddon(context.Background(), "dns") s.EnableAddon(context.Background(), "dns", "10.0.0.2") @@ -32,7 +32,7 @@ func TestAddons(t *testing.T) { t.Run("AddRepository", func(t *testing.T) { runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) s.AddAddonsRepository(context.Background(), "core", "/snap/microk8s/current/addons/core", "", false) s.AddAddonsRepository(context.Background(), "core", "/snap/microk8s/current/addons/core", "", true) diff --git a/pkg/snap/snap_capi_token_test.go b/pkg/snap/snap_capi_token_test.go new file mode 100644 index 0000000..dce3c74 --- /dev/null +++ b/pkg/snap/snap_capi_token_test.go @@ -0,0 +1,37 @@ +package snap_test + +import ( + "os" + "path/filepath" + "testing" + + . "github.com/onsi/gomega" + + "github.com/canonical/microk8s-cluster-agent/pkg/snap" +) + +func TestCAPIAuthToken(t *testing.T) { + capiTestPath := "./capi-test" + os.RemoveAll(capiTestPath) + s := snap.NewSnap("", "", "", capiTestPath) + token := "token123" + + g := NewWithT(t) + + isValid, err := s.IsCAPIAuthTokenValid(token) + g.Expect(err).To(MatchError(os.ErrNotExist)) + g.Expect(isValid).To(BeFalse()) + + capiEtc := filepath.Join(capiTestPath, "etc") + defer os.RemoveAll(capiTestPath) + g.Expect(os.MkdirAll(capiEtc, 0755)).To(Succeed()) + g.Expect(os.WriteFile("./capi-test/etc/token", []byte(token), 0600)).To(Succeed()) + + isValid, err = s.IsCAPIAuthTokenValid("random-token") + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(isValid).To(BeFalse()) + + isValid, err = s.IsCAPIAuthTokenValid(token) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(isValid).To(BeTrue()) +} diff --git a/pkg/snap/snap_containerd_test.go b/pkg/snap/snap_containerd_test.go index 39826de..d6adbb5 100644 --- a/pkg/snap/snap_containerd_test.go +++ b/pkg/snap/snap_containerd_test.go @@ -14,7 +14,7 @@ func TestUpdateContainerdRegistryConfigs(t *testing.T) { } defer os.RemoveAll("testdata/args") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") t.Run("Mirror", func(t *testing.T) { g := NewWithT(t) diff --git a/pkg/snap/snap_files_test.go b/pkg/snap/snap_files_test.go index ba9b606..7a906ee 100644 --- a/pkg/snap/snap_files_test.go +++ b/pkg/snap/snap_files_test.go @@ -47,7 +47,7 @@ func TestFiles(t *testing.T) { defer os.RemoveAll(filepath.Dir(file)) } - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") for _, tc := range []struct { name string diff --git a/pkg/snap/snap_images_test.go b/pkg/snap/snap_images_test.go index fe72bf8..2d92473 100644 --- a/pkg/snap/snap_images_test.go +++ b/pkg/snap/snap_images_test.go @@ -37,7 +37,7 @@ func TestImportImage(t *testing.T) { os.Remove("testdata/arguments") }() mockRunner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata/common", snap.WithCommandRunner(mockRunner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata/common", "", snap.WithCommandRunner(mockRunner.Run)) g := NewWithT(t) err := s.ImportImage(context.Background(), bytes.NewBufferString("IMAGEDATA")) diff --git a/pkg/snap/snap_join_test.go b/pkg/snap/snap_join_test.go index 1d311ed..815af8f 100644 --- a/pkg/snap/snap_join_test.go +++ b/pkg/snap/snap_join_test.go @@ -15,7 +15,7 @@ func TestJoinCluster(t *testing.T) { t.Run("PropagateError", func(t *testing.T) { g := NewWithT(t) runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) runner.Err = fmt.Errorf("some error") err := s.JoinCluster(context.Background(), "some-url", false) @@ -26,7 +26,7 @@ func TestJoinCluster(t *testing.T) { t.Run("ControlPlane", func(t *testing.T) { g := NewWithT(t) runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) err := s.JoinCluster(context.Background(), "10.10.10.10:25000/token/hash", false) g.Expect(err).To(BeNil()) @@ -36,7 +36,7 @@ func TestJoinCluster(t *testing.T) { t.Run("Worker", func(t *testing.T) { g := NewWithT(t) runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) err := s.JoinCluster(context.Background(), "10.10.10.10:25000/token/hash", true) g.Expect(err).To(BeNil()) diff --git a/pkg/snap/snap_lock_test.go b/pkg/snap/snap_lock_test.go index a38a2f9..25c5615 100644 --- a/pkg/snap/snap_lock_test.go +++ b/pkg/snap/snap_lock_test.go @@ -9,7 +9,7 @@ import ( ) func TestLock(t *testing.T) { - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") if err := os.MkdirAll("testdata/var/lock", 0755); err != nil { t.Fatalf("Failed to create directory: %s", err) } diff --git a/pkg/snap/snap_service_test.go b/pkg/snap/snap_service_test.go index e4ecc50..223911c 100644 --- a/pkg/snap/snap_service_test.go +++ b/pkg/snap/snap_service_test.go @@ -11,7 +11,7 @@ import ( func TestServiceRestart(t *testing.T) { mockRunner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(mockRunner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(mockRunner.Run)) t.Run("NoKubelite", func(t *testing.T) { for _, tc := range []struct { diff --git a/pkg/snap/snap_sign_test.go b/pkg/snap/snap_sign_test.go index 20f9941..153a818 100644 --- a/pkg/snap/snap_sign_test.go +++ b/pkg/snap/snap_sign_test.go @@ -35,7 +35,7 @@ func TestSignCertificate(t *testing.T) { os.Remove("testdata/arguments") }() mockRunner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(mockRunner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(mockRunner.Run)) g := NewWithT(t) b, err := s.SignCertificate(context.Background(), []byte("MOCK CSR")) diff --git a/pkg/snap/snap_token_test.go b/pkg/snap/snap_token_test.go index cfe163d..109875f 100644 --- a/pkg/snap/snap_token_test.go +++ b/pkg/snap/snap_token_test.go @@ -13,7 +13,7 @@ import ( func TestClusterTokens(t *testing.T) { os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") t.Run("MissingTokensFile", func(t *testing.T) { if s.ConsumeClusterToken("token1") { t.Fatal("Expected token1 to not be valid, but it is") @@ -94,7 +94,7 @@ func TestPersistentClusterToken(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") if err := s.AddPersistentClusterToken("my-token"); err != nil { t.Fatalf("Failed to add persistent cluster token: %s", err) } @@ -117,7 +117,7 @@ func TestCertificateRequestTokens(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") if err := s.AddCertificateRequestToken("my-token"); err != nil { t.Fatalf("Failed to add certificate request token: %s", err) } @@ -151,7 +151,7 @@ func TestCallbackTokens(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") if err := s.AddCallbackToken("ip:port", "my-token"); err != nil { t.Fatalf("Failed to add certificate request token: %s", err) } @@ -169,7 +169,7 @@ func TestSelfCallbackToken(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") token, err := s.GetOrCreateSelfCallbackToken() if err != nil { t.Fatalf("Failed to configure callback token: %q", err) @@ -194,7 +194,7 @@ func TestKnownTokens(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata") + s := snap.NewSnap("testdata", "testdata", "testdata", "") if token, err := s.GetKnownToken("user"); token != "" || err == nil { t.Fatalf("Expected an empty token and an error, but found token %s and error %s", token, err) } @@ -273,7 +273,7 @@ func TestStrictGroup(t *testing.T) { if err := os.WriteFile("testdata/meta/snapcraft.yaml", []byte(fmt.Sprintf("confinement: %s", tc.confinement)), 0660); err != nil { t.Fatalf("Failed to create test file: %s", err) } - group := snap.NewSnap("testdata", "testdata", "testdata").GetGroupName() + group := snap.NewSnap("testdata", "testdata", "testdata", "").GetGroupName() if tc.group != group { t.Fatalf("Expected group to be %q but it was %q instead", tc.group, group) } diff --git a/pkg/snap/snap_upgrade_test.go b/pkg/snap/snap_upgrade_test.go index 5363cf4..20761ac 100644 --- a/pkg/snap/snap_upgrade_test.go +++ b/pkg/snap/snap_upgrade_test.go @@ -31,7 +31,7 @@ func TestRunUpgrade(t *testing.T) { defer os.RemoveAll("testdata/upgrade-scripts") runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) t.Run("Invalid", func(t *testing.T) { for _, tc := range []struct { @@ -58,7 +58,7 @@ func TestRunUpgrade(t *testing.T) { t.Run(phase, func(t *testing.T) { runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) err := s.RunUpgrade(context.Background(), "001-custom-upgrade", phase) if err != nil { From d41addbc4b414f444a8c925bd5f20326f65f0856 Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Wed, 9 Oct 2024 17:31:40 +0400 Subject: [PATCH 2/9] Fix lint --- pkg/api/v2/consts.go | 1 + pkg/snap/mock/mock.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/api/v2/consts.go b/pkg/api/v2/consts.go index 368e646..2fb8a82 100644 --- a/pkg/api/v2/consts.go +++ b/pkg/api/v2/consts.go @@ -1,5 +1,6 @@ package v2 const ( + // CAPIAuthTokenHeader is the header used to pass the CAPI auth token. CAPIAuthTokenHeader = "capi-auth-token" ) diff --git a/pkg/snap/mock/mock.go b/pkg/snap/mock/mock.go index 8c5e001..bcc267c 100644 --- a/pkg/snap/mock/mock.go +++ b/pkg/snap/mock/mock.go @@ -329,6 +329,7 @@ func (s *Snap) GetKnownToken(username string) (string, error) { return "", fmt.Errorf("no known token for user %s", username) } +// IsCAPIAuthTokenValid is a mock implementation for the snap.Snap interface. func (s *Snap) IsCAPIAuthTokenValid(token string) (bool, error) { return s.CAPIAuthTokenValid, s.CAPIAuthTokenError } From f5d5ad2357683cda7900e6fae12380b577867157 Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Thu, 10 Oct 2024 16:45:14 +0400 Subject: [PATCH 3/9] Address comments and issues --- cmd/cluster-agent.go | 6 ------ cmd/consts.go | 5 ----- cmd/init.go | 6 ------ pkg/api/v2/remove.go | 7 +++++-- pkg/snap/options.go | 7 +++++++ pkg/snap/snap.go | 8 ++++++-- pkg/snap/snap_addons_test.go | 4 ++-- pkg/snap/snap_capi_token_test.go | 2 +- pkg/snap/snap_containerd_test.go | 2 +- pkg/snap/snap_files_test.go | 2 +- pkg/snap/snap_images_test.go | 2 +- pkg/snap/snap_join_test.go | 6 +++--- pkg/snap/snap_lock_test.go | 2 +- pkg/snap/snap_service_test.go | 2 +- pkg/snap/snap_sign_test.go | 2 +- pkg/snap/snap_token_test.go | 14 +++++++------- pkg/snap/snap_upgrade_test.go | 4 ++-- 17 files changed, 39 insertions(+), 42 deletions(-) delete mode 100644 cmd/consts.go diff --git a/cmd/cluster-agent.go b/cmd/cluster-agent.go index b769f63..28181f7 100644 --- a/cmd/cluster-agent.go +++ b/cmd/cluster-agent.go @@ -36,16 +36,10 @@ var clusterAgentCmd = &cobra.Command{ Long: `The MicroK8s cluster agent is an API server that orchestrates the lifecycle of a MicroK8s cluster.`, Run: func(cmd *cobra.Command, args []string) { - capiPath := os.Getenv("CAPI_PATH") - if capiPath == "" { - capiPath = capiDefaultPath - } - s := snap.NewSnap( os.Getenv("SNAP"), os.Getenv("SNAP_DATA"), os.Getenv("SNAP_COMMON"), - capiPath, snap.WithRetryApplyCNI(20, 3*time.Second), ) diff --git a/cmd/consts.go b/cmd/consts.go deleted file mode 100644 index 7c74383..0000000 --- a/cmd/consts.go +++ /dev/null @@ -1,5 +0,0 @@ -package cmd - -const ( - capiDefaultPath = "/capi" -) diff --git a/cmd/init.go b/cmd/init.go index 6c16119..0a0661e 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -19,16 +19,10 @@ var ( Short: "Apply MicroK8s configurations", Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { - capiPath := os.Getenv("CAPI_PATH") - if capiPath == "" { - capiPath = capiDefaultPath - } - s := snap.NewSnap( os.Getenv("SNAP"), os.Getenv("SNAP_DATA"), os.Getenv("SNAP_COMMON"), - capiPath, ) l := k8sinit.NewLauncher(s, initPreInit) diff --git a/pkg/api/v2/remove.go b/pkg/api/v2/remove.go index 54e32cc..c7d6119 100644 --- a/pkg/api/v2/remove.go +++ b/pkg/api/v2/remove.go @@ -16,9 +16,12 @@ type RemoveFromDqliteRequest struct { // RemoveFromDqlite implements the "POST /v2/dqlite/remove" endpoint and removes a node from the dqlite cluster. func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest, token string) (int, error) { - if isValid, err := a.Snap.IsCAPIAuthTokenValid(token); err != nil { + isValid, err := a.Snap.IsCAPIAuthTokenValid(token) + if err != nil { return http.StatusUnauthorized, fmt.Errorf("failed to validate CAPI auth token: %w", err) - } else if !isValid { + } + + if !isValid { return http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token) } diff --git a/pkg/snap/options.go b/pkg/snap/options.go index a439e41..f58456c 100644 --- a/pkg/snap/options.go +++ b/pkg/snap/options.go @@ -22,3 +22,10 @@ func WithCommandRunner(f func(context.Context, ...string) error) func(s *snap) { s.runCommand = f } } + +// WithCAPIPath configures the path to the CAPI directory. +func WithCAPIPath(path string) func(s *snap) { + return func(s *snap) { + s.capiPath = path + } +} diff --git a/pkg/snap/snap.go b/pkg/snap/snap.go index 489770c..3f63c48 100644 --- a/pkg/snap/snap.go +++ b/pkg/snap/snap.go @@ -35,14 +35,18 @@ type snap struct { applyCNIBackoff time.Duration } +const ( + defaultCAPIPath = "/capi" +) + // NewSnap creates a new interface with the MicroK8s snap. // NewSnap accepts the $SNAP, $SNAP_DATA and $SNAP_COMMON, directories, and a number of options. -func NewSnap(snapDir, snapDataDir, snapCommonDir, capiPath string, options ...func(s *snap)) Snap { +func NewSnap(snapDir, snapDataDir, snapCommonDir string, options ...func(s *snap)) Snap { s := &snap{ snapDir: snapDir, snapDataDir: snapDataDir, snapCommonDir: snapCommonDir, - capiPath: capiPath, + capiPath: defaultCAPIPath, runCommand: util.RunCommand, } diff --git a/pkg/snap/snap_addons_test.go b/pkg/snap/snap_addons_test.go index b106839..f2bd073 100644 --- a/pkg/snap/snap_addons_test.go +++ b/pkg/snap/snap_addons_test.go @@ -12,7 +12,7 @@ import ( func TestAddons(t *testing.T) { t.Run("EnableDisable", func(t *testing.T) { runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) s.EnableAddon(context.Background(), "dns") s.EnableAddon(context.Background(), "dns", "10.0.0.2") @@ -32,7 +32,7 @@ func TestAddons(t *testing.T) { t.Run("AddRepository", func(t *testing.T) { runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) s.AddAddonsRepository(context.Background(), "core", "/snap/microk8s/current/addons/core", "", false) s.AddAddonsRepository(context.Background(), "core", "/snap/microk8s/current/addons/core", "", true) diff --git a/pkg/snap/snap_capi_token_test.go b/pkg/snap/snap_capi_token_test.go index dce3c74..23bdb96 100644 --- a/pkg/snap/snap_capi_token_test.go +++ b/pkg/snap/snap_capi_token_test.go @@ -13,7 +13,7 @@ import ( func TestCAPIAuthToken(t *testing.T) { capiTestPath := "./capi-test" os.RemoveAll(capiTestPath) - s := snap.NewSnap("", "", "", capiTestPath) + s := snap.NewSnap("", "", "", snap.WithCAPIPath(capiTestPath)) token := "token123" g := NewWithT(t) diff --git a/pkg/snap/snap_containerd_test.go b/pkg/snap/snap_containerd_test.go index d6adbb5..39826de 100644 --- a/pkg/snap/snap_containerd_test.go +++ b/pkg/snap/snap_containerd_test.go @@ -14,7 +14,7 @@ func TestUpdateContainerdRegistryConfigs(t *testing.T) { } defer os.RemoveAll("testdata/args") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") t.Run("Mirror", func(t *testing.T) { g := NewWithT(t) diff --git a/pkg/snap/snap_files_test.go b/pkg/snap/snap_files_test.go index 7a906ee..ba9b606 100644 --- a/pkg/snap/snap_files_test.go +++ b/pkg/snap/snap_files_test.go @@ -47,7 +47,7 @@ func TestFiles(t *testing.T) { defer os.RemoveAll(filepath.Dir(file)) } - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") for _, tc := range []struct { name string diff --git a/pkg/snap/snap_images_test.go b/pkg/snap/snap_images_test.go index 2d92473..fe72bf8 100644 --- a/pkg/snap/snap_images_test.go +++ b/pkg/snap/snap_images_test.go @@ -37,7 +37,7 @@ func TestImportImage(t *testing.T) { os.Remove("testdata/arguments") }() mockRunner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata/common", "", snap.WithCommandRunner(mockRunner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata/common", snap.WithCommandRunner(mockRunner.Run)) g := NewWithT(t) err := s.ImportImage(context.Background(), bytes.NewBufferString("IMAGEDATA")) diff --git a/pkg/snap/snap_join_test.go b/pkg/snap/snap_join_test.go index 815af8f..1d311ed 100644 --- a/pkg/snap/snap_join_test.go +++ b/pkg/snap/snap_join_test.go @@ -15,7 +15,7 @@ func TestJoinCluster(t *testing.T) { t.Run("PropagateError", func(t *testing.T) { g := NewWithT(t) runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) runner.Err = fmt.Errorf("some error") err := s.JoinCluster(context.Background(), "some-url", false) @@ -26,7 +26,7 @@ func TestJoinCluster(t *testing.T) { t.Run("ControlPlane", func(t *testing.T) { g := NewWithT(t) runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) err := s.JoinCluster(context.Background(), "10.10.10.10:25000/token/hash", false) g.Expect(err).To(BeNil()) @@ -36,7 +36,7 @@ func TestJoinCluster(t *testing.T) { t.Run("Worker", func(t *testing.T) { g := NewWithT(t) runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) err := s.JoinCluster(context.Background(), "10.10.10.10:25000/token/hash", true) g.Expect(err).To(BeNil()) diff --git a/pkg/snap/snap_lock_test.go b/pkg/snap/snap_lock_test.go index 25c5615..a38a2f9 100644 --- a/pkg/snap/snap_lock_test.go +++ b/pkg/snap/snap_lock_test.go @@ -9,7 +9,7 @@ import ( ) func TestLock(t *testing.T) { - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") if err := os.MkdirAll("testdata/var/lock", 0755); err != nil { t.Fatalf("Failed to create directory: %s", err) } diff --git a/pkg/snap/snap_service_test.go b/pkg/snap/snap_service_test.go index 223911c..e4ecc50 100644 --- a/pkg/snap/snap_service_test.go +++ b/pkg/snap/snap_service_test.go @@ -11,7 +11,7 @@ import ( func TestServiceRestart(t *testing.T) { mockRunner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(mockRunner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(mockRunner.Run)) t.Run("NoKubelite", func(t *testing.T) { for _, tc := range []struct { diff --git a/pkg/snap/snap_sign_test.go b/pkg/snap/snap_sign_test.go index 153a818..20f9941 100644 --- a/pkg/snap/snap_sign_test.go +++ b/pkg/snap/snap_sign_test.go @@ -35,7 +35,7 @@ func TestSignCertificate(t *testing.T) { os.Remove("testdata/arguments") }() mockRunner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(mockRunner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(mockRunner.Run)) g := NewWithT(t) b, err := s.SignCertificate(context.Background(), []byte("MOCK CSR")) diff --git a/pkg/snap/snap_token_test.go b/pkg/snap/snap_token_test.go index 109875f..cfe163d 100644 --- a/pkg/snap/snap_token_test.go +++ b/pkg/snap/snap_token_test.go @@ -13,7 +13,7 @@ import ( func TestClusterTokens(t *testing.T) { os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") t.Run("MissingTokensFile", func(t *testing.T) { if s.ConsumeClusterToken("token1") { t.Fatal("Expected token1 to not be valid, but it is") @@ -94,7 +94,7 @@ func TestPersistentClusterToken(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") if err := s.AddPersistentClusterToken("my-token"); err != nil { t.Fatalf("Failed to add persistent cluster token: %s", err) } @@ -117,7 +117,7 @@ func TestCertificateRequestTokens(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") if err := s.AddCertificateRequestToken("my-token"); err != nil { t.Fatalf("Failed to add certificate request token: %s", err) } @@ -151,7 +151,7 @@ func TestCallbackTokens(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") if err := s.AddCallbackToken("ip:port", "my-token"); err != nil { t.Fatalf("Failed to add certificate request token: %s", err) } @@ -169,7 +169,7 @@ func TestSelfCallbackToken(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") token, err := s.GetOrCreateSelfCallbackToken() if err != nil { t.Fatalf("Failed to configure callback token: %q", err) @@ -194,7 +194,7 @@ func TestKnownTokens(t *testing.T) { t.Fatalf("Failed to create test directory: %s", err) } defer os.RemoveAll("testdata/credentials") - s := snap.NewSnap("testdata", "testdata", "testdata", "") + s := snap.NewSnap("testdata", "testdata", "testdata") if token, err := s.GetKnownToken("user"); token != "" || err == nil { t.Fatalf("Expected an empty token and an error, but found token %s and error %s", token, err) } @@ -273,7 +273,7 @@ func TestStrictGroup(t *testing.T) { if err := os.WriteFile("testdata/meta/snapcraft.yaml", []byte(fmt.Sprintf("confinement: %s", tc.confinement)), 0660); err != nil { t.Fatalf("Failed to create test file: %s", err) } - group := snap.NewSnap("testdata", "testdata", "testdata", "").GetGroupName() + group := snap.NewSnap("testdata", "testdata", "testdata").GetGroupName() if tc.group != group { t.Fatalf("Expected group to be %q but it was %q instead", tc.group, group) } diff --git a/pkg/snap/snap_upgrade_test.go b/pkg/snap/snap_upgrade_test.go index 20761ac..5363cf4 100644 --- a/pkg/snap/snap_upgrade_test.go +++ b/pkg/snap/snap_upgrade_test.go @@ -31,7 +31,7 @@ func TestRunUpgrade(t *testing.T) { defer os.RemoveAll("testdata/upgrade-scripts") runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) t.Run("Invalid", func(t *testing.T) { for _, tc := range []struct { @@ -58,7 +58,7 @@ func TestRunUpgrade(t *testing.T) { t.Run(phase, func(t *testing.T) { runner := &utiltest.MockRunner{} - s := snap.NewSnap("testdata", "testdata", "testdata", "", snap.WithCommandRunner(runner.Run)) + s := snap.NewSnap("testdata", "testdata", "testdata", snap.WithCommandRunner(runner.Run)) err := s.RunUpgrade(context.Background(), "001-custom-upgrade", phase) if err != nil { From 22a3bb7e15522590a85aa5b549ba3b6f181d7315 Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Fri, 11 Oct 2024 10:37:10 +0400 Subject: [PATCH 4/9] Address comments and issues --- pkg/api/v2/remove.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/api/v2/remove.go b/pkg/api/v2/remove.go index c7d6119..0fd00c0 100644 --- a/pkg/api/v2/remove.go +++ b/pkg/api/v2/remove.go @@ -18,7 +18,7 @@ type RemoveFromDqliteRequest struct { func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest, token string) (int, error) { isValid, err := a.Snap.IsCAPIAuthTokenValid(token) if err != nil { - return http.StatusUnauthorized, fmt.Errorf("failed to validate CAPI auth token: %w", err) + return http.StatusInternalServerError, fmt.Errorf("failed to validate CAPI auth token: %w", err) } if !isValid { From e711313ea64bbb484e281353edea66808222acb8 Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Fri, 11 Oct 2024 11:56:21 +0400 Subject: [PATCH 5/9] Fix test --- pkg/api/v2/remove_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/api/v2/remove_test.go b/pkg/api/v2/remove_test.go index b1ff464..2371b4d 100644 --- a/pkg/api/v2/remove_test.go +++ b/pkg/api/v2/remove_test.go @@ -55,7 +55,7 @@ func TestRemove(t *testing.T) { g := NewWithT(t) g.Expect(err).To(MatchError(tokenErr)) - g.Expect(rc).To(Equal(http.StatusUnauthorized)) + g.Expect(rc).To(Equal(http.StatusInternalServerError)) }) t.Run("RemovesSuccessfully", func(t *testing.T) { From 2bcaeca05ce83c25607faae23b53b8fd7baa27ad Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Fri, 11 Oct 2024 13:06:44 +0400 Subject: [PATCH 6/9] Add middleware for capi auth token --- pkg/api/v2/consts.go | 6 --- pkg/api/v2/register.go | 41 ++++++++-------- pkg/api/v2/remove.go | 11 +---- pkg/api/v2/remove_test.go | 40 ++-------------- pkg/middleware/capi.go | 39 +++++++++++++++ pkg/middleware/capi_test.go | 96 +++++++++++++++++++++++++++++++++++++ pkg/server/server.go | 7 ++- 7 files changed, 168 insertions(+), 72 deletions(-) delete mode 100644 pkg/api/v2/consts.go create mode 100644 pkg/middleware/capi.go create mode 100644 pkg/middleware/capi_test.go diff --git a/pkg/api/v2/consts.go b/pkg/api/v2/consts.go deleted file mode 100644 index 2fb8a82..0000000 --- a/pkg/api/v2/consts.go +++ /dev/null @@ -1,6 +0,0 @@ -package v2 - -const ( - // CAPIAuthTokenHeader is the header used to pass the CAPI auth token. - CAPIAuthTokenHeader = "capi-auth-token" -) diff --git a/pkg/api/v2/register.go b/pkg/api/v2/register.go index 115bd46..3e676da 100644 --- a/pkg/api/v2/register.go +++ b/pkg/api/v2/register.go @@ -5,13 +5,14 @@ import ( "net/http" "github.com/canonical/microk8s-cluster-agent/pkg/httputil" + "github.com/canonical/microk8s-cluster-agent/pkg/snap" ) // HTTPPrefix is the prefix for all v2 API routes. const HTTPPrefix = "/cluster/api/v2.0" // RegisterServer registers the Cluster API v2 endpoints on an HTTP server. -func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.HandlerFunc) http.HandlerFunc) { +func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.HandlerFunc) http.HandlerFunc, capiAuthMiddleware func(http.HandlerFunc, snap.Snap) http.HandlerFunc) { // POST v2/join server.HandleFunc(fmt.Sprintf("%s/join", HTTPPrefix), middleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -55,25 +56,27 @@ func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.Handl })) // POST v2/dqlite/remove - server.HandleFunc(fmt.Sprintf("%s/dqlite/remove", HTTPPrefix), middleware(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - req := RemoveFromDqliteRequest{} - if err := httputil.UnmarshalJSON(r, &req); err != nil { - httputil.Error(w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal JSON: %w", err)) - return - } + server.HandleFunc(fmt.Sprintf("%s/dqlite/remove", HTTPPrefix), + middleware(func(w http.ResponseWriter, r *http.Request) { + capiAuthMiddleware(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } - token := r.Header.Get(CAPIAuthTokenHeader) + req := RemoveFromDqliteRequest{} + if err := httputil.UnmarshalJSON(r, &req); err != nil { + httputil.Error(w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal JSON: %w", err)) + return + } - if rc, err := a.RemoveFromDqlite(r.Context(), req, token); err != nil { - httputil.Error(w, rc, fmt.Errorf("failed to remove from dqlite: %w", err)) - return - } + if rc, err := a.RemoveFromDqlite(r.Context(), req); err != nil { + httputil.Error(w, rc, fmt.Errorf("failed to remove from dqlite: %w", err)) + return + } - httputil.Response(w, nil) - })) + httputil.Response(w, nil) + }, a.Snap) + }), + ) } diff --git a/pkg/api/v2/remove.go b/pkg/api/v2/remove.go index 0fd00c0..fe6d9fa 100644 --- a/pkg/api/v2/remove.go +++ b/pkg/api/v2/remove.go @@ -15,16 +15,7 @@ type RemoveFromDqliteRequest struct { } // RemoveFromDqlite implements the "POST /v2/dqlite/remove" endpoint and removes a node from the dqlite cluster. -func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest, token string) (int, error) { - isValid, err := a.Snap.IsCAPIAuthTokenValid(token) - if err != nil { - return http.StatusInternalServerError, fmt.Errorf("failed to validate CAPI auth token: %w", err) - } - - if !isValid { - return http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token) - } - +func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest) (int, error) { if err := snaputil.RemoveNodeFromDqlite(ctx, a.Snap, req.RemoveEndpoint); err != nil { return http.StatusInternalServerError, fmt.Errorf("failed to remove node from dqlite: %w", err) } diff --git a/pkg/api/v2/remove_test.go b/pkg/api/v2/remove_test.go index 2371b4d..5ad4aba 100644 --- a/pkg/api/v2/remove_test.go +++ b/pkg/api/v2/remove_test.go @@ -17,55 +17,23 @@ func TestRemove(t *testing.T) { cmdErr := errors.New("failed to run command") apiv2 := &v2.API{ Snap: &mock.Snap{ - RunCommandErr: cmdErr, - CAPIAuthTokenValid: true, + RunCommandErr: cmdErr, }, } - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}) g := NewWithT(t) g.Expect(err).To(MatchError(cmdErr)) g.Expect(rc).To(Equal(http.StatusInternalServerError)) }) - t.Run("InvalidToken", func(t *testing.T) { - apiv2 := &v2.API{ - Snap: &mock.Snap{ - CAPIAuthTokenValid: false, // explicitly set to false - }, - } - - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") - - g := NewWithT(t) - g.Expect(err).To(HaveOccurred()) - g.Expect(rc).To(Equal(http.StatusUnauthorized)) - }) - - t.Run("TokenFileNotFound", func(t *testing.T) { - tokenErr := errors.New("token file not found") - apiv2 := &v2.API{ - Snap: &mock.Snap{ - CAPIAuthTokenError: tokenErr, - }, - } - - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") - - g := NewWithT(t) - g.Expect(err).To(MatchError(tokenErr)) - g.Expect(rc).To(Equal(http.StatusInternalServerError)) - }) - t.Run("RemovesSuccessfully", func(t *testing.T) { apiv2 := &v2.API{ - Snap: &mock.Snap{ - CAPIAuthTokenValid: true, - }, + Snap: &mock.Snap{}, } - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}) g := NewWithT(t) g.Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/middleware/capi.go b/pkg/middleware/capi.go new file mode 100644 index 0000000..e7555b0 --- /dev/null +++ b/pkg/middleware/capi.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/canonical/microk8s-cluster-agent/pkg/httputil" + "github.com/canonical/microk8s-cluster-agent/pkg/snap" +) + +const ( + // CAPIAuthTokenHeader is the header used to pass the CAPI auth token. + CAPIAuthTokenHeader = "capi-auth-token" +) + +func CAPIAuthToken(next http.HandlerFunc, snap snap.Snap) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get(CAPIAuthTokenHeader) + fmt.Println(r.Header, "-->", r.Header.Get(CAPIAuthTokenHeader)) + fmt.Println("token", token) + if token == "" { + httputil.Error(w, http.StatusUnauthorized, fmt.Errorf("missing CAPI auth token")) + return + } + + isValid, err := snap.IsCAPIAuthTokenValid(token) + if err != nil { + httputil.Error(w, http.StatusInternalServerError, fmt.Errorf("failed to validate CAPI auth token: %w", err)) + return + } + + if !isValid { + httputil.Error(w, http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token)) + return + } + + next.ServeHTTP(w, r) + } +} diff --git a/pkg/middleware/capi_test.go b/pkg/middleware/capi_test.go new file mode 100644 index 0000000..b589ce9 --- /dev/null +++ b/pkg/middleware/capi_test.go @@ -0,0 +1,96 @@ +package middleware_test + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + . "github.com/onsi/gomega" + + "github.com/canonical/microk8s-cluster-agent/pkg/middleware" + "github.com/canonical/microk8s-cluster-agent/pkg/snap/mock" +) + +type fakeNext struct { + isCalled bool +} + +func (f *fakeNext) next(w http.ResponseWriter, r *http.Request) { + f.isCalled = true +} + +func TestCAPIAuth(t *testing.T) { + t.Run("NoTokenHeader", func(t *testing.T) { + r := &http.Request{} + fake := &fakeNext{} + fn := middleware.CAPIAuthToken(fake.next, nil) + w := httptest.NewRecorder() + fn(w, r) + + g := NewWithT(t) + + g.Expect(w.Result().StatusCode).To(Equal(http.StatusUnauthorized)) + g.Expect(fake.isCalled).To(BeFalse()) + }) + + t.Run("InvalidToken", func(t *testing.T) { + r := &http.Request{ + Header: http.Header{ + http.CanonicalHeaderKey(middleware.CAPIAuthTokenHeader): []string{"invalid-token"}, + }, + } + fake := &fakeNext{} + snapM := &mock.Snap{ + CAPIAuthTokenValid: false, // explicit + } + fn := middleware.CAPIAuthToken(fake.next, snapM) + w := httptest.NewRecorder() + fn(w, r) + + g := NewWithT(t) + + g.Expect(w.Result().StatusCode).To(Equal(http.StatusUnauthorized)) + g.Expect(fake.isCalled).To(BeFalse()) + }) + + t.Run("FailedToValidate", func(t *testing.T) { + r := &http.Request{ + Header: http.Header{ + http.CanonicalHeaderKey(middleware.CAPIAuthTokenHeader): []string{"invalid-token"}, + }, + } + fake := &fakeNext{} + validateErr := errors.New("failed to validate") + snapM := &mock.Snap{ + CAPIAuthTokenError: validateErr, + } + fn := middleware.CAPIAuthToken(fake.next, snapM) + w := httptest.NewRecorder() + fn(w, r) + + g := NewWithT(t) + + g.Expect(w.Result().StatusCode).To(Equal(http.StatusInternalServerError)) + g.Expect(fake.isCalled).To(BeFalse()) + }) + + t.Run("Success", func(t *testing.T) { + r := &http.Request{ + Header: http.Header{ + http.CanonicalHeaderKey(middleware.CAPIAuthTokenHeader): []string{"valid-token"}, + }, + } + fake := &fakeNext{} + snapM := &mock.Snap{ + CAPIAuthTokenValid: true, + } + fn := middleware.CAPIAuthToken(fake.next, snapM) + w := httptest.NewRecorder() + fn(w, r) + + g := NewWithT(t) + + g.Expect(fake.isCalled).To(BeTrue()) + }) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 9bdfc5e..8a47a02 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -9,6 +9,7 @@ import ( v2 "github.com/canonical/microk8s-cluster-agent/pkg/api/v2" "github.com/canonical/microk8s-cluster-agent/pkg/httputil" "github.com/canonical/microk8s-cluster-agent/pkg/middleware" + "github.com/canonical/microk8s-cluster-agent/pkg/snap" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -21,6 +22,10 @@ func NewServeMux(timeout time.Duration, enableMetrics bool, apiv1 *v1.API, apiv2 return middleware.Log(timeoutMiddleware(f)) } + capiAuthMiddleWare := func(f http.HandlerFunc, snp snap.Snap) http.HandlerFunc { + return middleware.CAPIAuthToken(f, snp) + } + // Default handler server.HandleFunc("/", withMiddleware(func(w http.ResponseWriter, r *http.Request) { httputil.Error(w, http.StatusNotFound, fmt.Errorf("not found")) @@ -43,7 +48,7 @@ func NewServeMux(timeout time.Duration, enableMetrics bool, apiv1 *v1.API, apiv2 // Cluster Agent API apiv1.RegisterServer(server, withMiddleware) - apiv2.RegisterServer(server, withMiddleware) + apiv2.RegisterServer(server, withMiddleware, capiAuthMiddleWare) return server } From 2496acffd5c55b561f3e76378b9dbccd0eefe3bd Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Fri, 11 Oct 2024 13:09:51 +0400 Subject: [PATCH 7/9] Fix lint --- pkg/middleware/capi.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/middleware/capi.go b/pkg/middleware/capi.go index e7555b0..0397b2d 100644 --- a/pkg/middleware/capi.go +++ b/pkg/middleware/capi.go @@ -13,6 +13,7 @@ const ( CAPIAuthTokenHeader = "capi-auth-token" ) +// CAPIAuthToken is a middleware that checks the CAPI auth token. func CAPIAuthToken(next http.HandlerFunc, snap snap.Snap) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get(CAPIAuthTokenHeader) From 6e986affaa14ce8476109932d0cd1df4837c88b2 Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Fri, 11 Oct 2024 13:55:53 +0400 Subject: [PATCH 8/9] Revert "Fix lint" This reverts commit 2496acffd5c55b561f3e76378b9dbccd0eefe3bd. --- pkg/middleware/capi.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/middleware/capi.go b/pkg/middleware/capi.go index 0397b2d..e7555b0 100644 --- a/pkg/middleware/capi.go +++ b/pkg/middleware/capi.go @@ -13,7 +13,6 @@ const ( CAPIAuthTokenHeader = "capi-auth-token" ) -// CAPIAuthToken is a middleware that checks the CAPI auth token. func CAPIAuthToken(next http.HandlerFunc, snap snap.Snap) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get(CAPIAuthTokenHeader) From f4e581919d5fbe6758e416e58a6621422f7ba5c7 Mon Sep 17 00:00:00 2001 From: "Homayoon (Hue) Alimohammadi" Date: Fri, 11 Oct 2024 13:56:05 +0400 Subject: [PATCH 9/9] Revert "Add middleware for capi auth token" This reverts commit 2bcaeca05ce83c25607faae23b53b8fd7baa27ad. --- pkg/api/v2/consts.go | 6 +++ pkg/api/v2/register.go | 41 ++++++++-------- pkg/api/v2/remove.go | 11 ++++- pkg/api/v2/remove_test.go | 40 ++++++++++++++-- pkg/middleware/capi.go | 39 --------------- pkg/middleware/capi_test.go | 96 ------------------------------------- pkg/server/server.go | 7 +-- 7 files changed, 72 insertions(+), 168 deletions(-) create mode 100644 pkg/api/v2/consts.go delete mode 100644 pkg/middleware/capi.go delete mode 100644 pkg/middleware/capi_test.go diff --git a/pkg/api/v2/consts.go b/pkg/api/v2/consts.go new file mode 100644 index 0000000..2fb8a82 --- /dev/null +++ b/pkg/api/v2/consts.go @@ -0,0 +1,6 @@ +package v2 + +const ( + // CAPIAuthTokenHeader is the header used to pass the CAPI auth token. + CAPIAuthTokenHeader = "capi-auth-token" +) diff --git a/pkg/api/v2/register.go b/pkg/api/v2/register.go index 3e676da..115bd46 100644 --- a/pkg/api/v2/register.go +++ b/pkg/api/v2/register.go @@ -5,14 +5,13 @@ import ( "net/http" "github.com/canonical/microk8s-cluster-agent/pkg/httputil" - "github.com/canonical/microk8s-cluster-agent/pkg/snap" ) // HTTPPrefix is the prefix for all v2 API routes. const HTTPPrefix = "/cluster/api/v2.0" // RegisterServer registers the Cluster API v2 endpoints on an HTTP server. -func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.HandlerFunc) http.HandlerFunc, capiAuthMiddleware func(http.HandlerFunc, snap.Snap) http.HandlerFunc) { +func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.HandlerFunc) http.HandlerFunc) { // POST v2/join server.HandleFunc(fmt.Sprintf("%s/join", HTTPPrefix), middleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -56,27 +55,25 @@ func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.Handl })) // POST v2/dqlite/remove - server.HandleFunc(fmt.Sprintf("%s/dqlite/remove", HTTPPrefix), - middleware(func(w http.ResponseWriter, r *http.Request) { - capiAuthMiddleware(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } + server.HandleFunc(fmt.Sprintf("%s/dqlite/remove", HTTPPrefix), middleware(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + req := RemoveFromDqliteRequest{} + if err := httputil.UnmarshalJSON(r, &req); err != nil { + httputil.Error(w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal JSON: %w", err)) + return + } - req := RemoveFromDqliteRequest{} - if err := httputil.UnmarshalJSON(r, &req); err != nil { - httputil.Error(w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal JSON: %w", err)) - return - } + token := r.Header.Get(CAPIAuthTokenHeader) - if rc, err := a.RemoveFromDqlite(r.Context(), req); err != nil { - httputil.Error(w, rc, fmt.Errorf("failed to remove from dqlite: %w", err)) - return - } + if rc, err := a.RemoveFromDqlite(r.Context(), req, token); err != nil { + httputil.Error(w, rc, fmt.Errorf("failed to remove from dqlite: %w", err)) + return + } - httputil.Response(w, nil) - }, a.Snap) - }), - ) + httputil.Response(w, nil) + })) } diff --git a/pkg/api/v2/remove.go b/pkg/api/v2/remove.go index fe6d9fa..0fd00c0 100644 --- a/pkg/api/v2/remove.go +++ b/pkg/api/v2/remove.go @@ -15,7 +15,16 @@ type RemoveFromDqliteRequest struct { } // RemoveFromDqlite implements the "POST /v2/dqlite/remove" endpoint and removes a node from the dqlite cluster. -func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest) (int, error) { +func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest, token string) (int, error) { + isValid, err := a.Snap.IsCAPIAuthTokenValid(token) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to validate CAPI auth token: %w", err) + } + + if !isValid { + return http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token) + } + if err := snaputil.RemoveNodeFromDqlite(ctx, a.Snap, req.RemoveEndpoint); err != nil { return http.StatusInternalServerError, fmt.Errorf("failed to remove node from dqlite: %w", err) } diff --git a/pkg/api/v2/remove_test.go b/pkg/api/v2/remove_test.go index 5ad4aba..2371b4d 100644 --- a/pkg/api/v2/remove_test.go +++ b/pkg/api/v2/remove_test.go @@ -17,23 +17,55 @@ func TestRemove(t *testing.T) { cmdErr := errors.New("failed to run command") apiv2 := &v2.API{ Snap: &mock.Snap{ - RunCommandErr: cmdErr, + RunCommandErr: cmdErr, + CAPIAuthTokenValid: true, }, } - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}) + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") g := NewWithT(t) g.Expect(err).To(MatchError(cmdErr)) g.Expect(rc).To(Equal(http.StatusInternalServerError)) }) + t.Run("InvalidToken", func(t *testing.T) { + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenValid: false, // explicitly set to false + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(HaveOccurred()) + g.Expect(rc).To(Equal(http.StatusUnauthorized)) + }) + + t.Run("TokenFileNotFound", func(t *testing.T) { + tokenErr := errors.New("token file not found") + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenError: tokenErr, + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(MatchError(tokenErr)) + g.Expect(rc).To(Equal(http.StatusInternalServerError)) + }) + t.Run("RemovesSuccessfully", func(t *testing.T) { apiv2 := &v2.API{ - Snap: &mock.Snap{}, + Snap: &mock.Snap{ + CAPIAuthTokenValid: true, + }, } - rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}) + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") g := NewWithT(t) g.Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/middleware/capi.go b/pkg/middleware/capi.go deleted file mode 100644 index e7555b0..0000000 --- a/pkg/middleware/capi.go +++ /dev/null @@ -1,39 +0,0 @@ -package middleware - -import ( - "fmt" - "net/http" - - "github.com/canonical/microk8s-cluster-agent/pkg/httputil" - "github.com/canonical/microk8s-cluster-agent/pkg/snap" -) - -const ( - // CAPIAuthTokenHeader is the header used to pass the CAPI auth token. - CAPIAuthTokenHeader = "capi-auth-token" -) - -func CAPIAuthToken(next http.HandlerFunc, snap snap.Snap) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get(CAPIAuthTokenHeader) - fmt.Println(r.Header, "-->", r.Header.Get(CAPIAuthTokenHeader)) - fmt.Println("token", token) - if token == "" { - httputil.Error(w, http.StatusUnauthorized, fmt.Errorf("missing CAPI auth token")) - return - } - - isValid, err := snap.IsCAPIAuthTokenValid(token) - if err != nil { - httputil.Error(w, http.StatusInternalServerError, fmt.Errorf("failed to validate CAPI auth token: %w", err)) - return - } - - if !isValid { - httputil.Error(w, http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token)) - return - } - - next.ServeHTTP(w, r) - } -} diff --git a/pkg/middleware/capi_test.go b/pkg/middleware/capi_test.go deleted file mode 100644 index b589ce9..0000000 --- a/pkg/middleware/capi_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package middleware_test - -import ( - "errors" - "net/http" - "net/http/httptest" - "testing" - - . "github.com/onsi/gomega" - - "github.com/canonical/microk8s-cluster-agent/pkg/middleware" - "github.com/canonical/microk8s-cluster-agent/pkg/snap/mock" -) - -type fakeNext struct { - isCalled bool -} - -func (f *fakeNext) next(w http.ResponseWriter, r *http.Request) { - f.isCalled = true -} - -func TestCAPIAuth(t *testing.T) { - t.Run("NoTokenHeader", func(t *testing.T) { - r := &http.Request{} - fake := &fakeNext{} - fn := middleware.CAPIAuthToken(fake.next, nil) - w := httptest.NewRecorder() - fn(w, r) - - g := NewWithT(t) - - g.Expect(w.Result().StatusCode).To(Equal(http.StatusUnauthorized)) - g.Expect(fake.isCalled).To(BeFalse()) - }) - - t.Run("InvalidToken", func(t *testing.T) { - r := &http.Request{ - Header: http.Header{ - http.CanonicalHeaderKey(middleware.CAPIAuthTokenHeader): []string{"invalid-token"}, - }, - } - fake := &fakeNext{} - snapM := &mock.Snap{ - CAPIAuthTokenValid: false, // explicit - } - fn := middleware.CAPIAuthToken(fake.next, snapM) - w := httptest.NewRecorder() - fn(w, r) - - g := NewWithT(t) - - g.Expect(w.Result().StatusCode).To(Equal(http.StatusUnauthorized)) - g.Expect(fake.isCalled).To(BeFalse()) - }) - - t.Run("FailedToValidate", func(t *testing.T) { - r := &http.Request{ - Header: http.Header{ - http.CanonicalHeaderKey(middleware.CAPIAuthTokenHeader): []string{"invalid-token"}, - }, - } - fake := &fakeNext{} - validateErr := errors.New("failed to validate") - snapM := &mock.Snap{ - CAPIAuthTokenError: validateErr, - } - fn := middleware.CAPIAuthToken(fake.next, snapM) - w := httptest.NewRecorder() - fn(w, r) - - g := NewWithT(t) - - g.Expect(w.Result().StatusCode).To(Equal(http.StatusInternalServerError)) - g.Expect(fake.isCalled).To(BeFalse()) - }) - - t.Run("Success", func(t *testing.T) { - r := &http.Request{ - Header: http.Header{ - http.CanonicalHeaderKey(middleware.CAPIAuthTokenHeader): []string{"valid-token"}, - }, - } - fake := &fakeNext{} - snapM := &mock.Snap{ - CAPIAuthTokenValid: true, - } - fn := middleware.CAPIAuthToken(fake.next, snapM) - w := httptest.NewRecorder() - fn(w, r) - - g := NewWithT(t) - - g.Expect(fake.isCalled).To(BeTrue()) - }) -} diff --git a/pkg/server/server.go b/pkg/server/server.go index 8a47a02..9bdfc5e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -9,7 +9,6 @@ import ( v2 "github.com/canonical/microk8s-cluster-agent/pkg/api/v2" "github.com/canonical/microk8s-cluster-agent/pkg/httputil" "github.com/canonical/microk8s-cluster-agent/pkg/middleware" - "github.com/canonical/microk8s-cluster-agent/pkg/snap" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -22,10 +21,6 @@ func NewServeMux(timeout time.Duration, enableMetrics bool, apiv1 *v1.API, apiv2 return middleware.Log(timeoutMiddleware(f)) } - capiAuthMiddleWare := func(f http.HandlerFunc, snp snap.Snap) http.HandlerFunc { - return middleware.CAPIAuthToken(f, snp) - } - // Default handler server.HandleFunc("/", withMiddleware(func(w http.ResponseWriter, r *http.Request) { httputil.Error(w, http.StatusNotFound, fmt.Errorf("not found")) @@ -48,7 +43,7 @@ func NewServeMux(timeout time.Duration, enableMetrics bool, apiv1 *v1.API, apiv2 // Cluster Agent API apiv1.RegisterServer(server, withMiddleware) - apiv2.RegisterServer(server, withMiddleware, capiAuthMiddleWare) + apiv2.RegisterServer(server, withMiddleware) return server }