diff --git a/cloud/linode/client/client.go b/cloud/linode/client/client.go index 639bdcfe..ee32d60e 100644 --- a/cloud/linode/client/client.go +++ b/cloud/linode/client/client.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "os" + "strings" "time" "github.com/linode/linodego" @@ -52,6 +53,8 @@ type Client interface { DeleteFirewall(ctx context.Context, fwid int) error GetFirewall(context.Context, int) (*linodego.Firewall, error) UpdateFirewallRules(context.Context, int, linodego.FirewallRuleSet) (*linodego.FirewallRuleSet, error) + + GetProfile(ctx context.Context) (*linodego.Profile, error) } // linodego.Client implements Client @@ -73,3 +76,16 @@ func New(token string, timeout time.Duration) (*linodego.Client, error) { klog.V(3).Infof("Linode client created with default timeout of %v", timeout) return client, nil } + +func CheckClientAuthenticated(ctx context.Context, client Client) (bool, error) { + _, err := client.GetProfile(ctx) + if err == nil { + return true, nil + } + + if strings.Contains(err.Error(), "Invalid Token") { + return false, nil + } + + return false, err +} diff --git a/cloud/linode/client/mocks/mock_client.go b/cloud/linode/client/mocks/mock_client.go index f8baf6b2..c986aef2 100644 --- a/cloud/linode/client/mocks/mock_client.go +++ b/cloud/linode/client/mocks/mock_client.go @@ -255,6 +255,21 @@ func (mr *MockClientMockRecorder) GetNodeBalancer(arg0, arg1 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeBalancer", reflect.TypeOf((*MockClient)(nil).GetNodeBalancer), arg0, arg1) } +// GetProfile mocks base method. +func (m *MockClient) GetProfile(arg0 context.Context) (*linodego.Profile, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProfile", arg0) + ret0, _ := ret[0].(*linodego.Profile) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProfile indicates an expected call of GetProfile. +func (mr *MockClientMockRecorder) GetProfile(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProfile", reflect.TypeOf((*MockClient)(nil).GetProfile), arg0) +} + // ListFirewallDevices mocks base method. func (m *MockClient) ListFirewallDevices(arg0 context.Context, arg1 int, arg2 *linodego.ListOptions) ([]linodego.FirewallDevice, error) { m.ctrl.T.Helper() diff --git a/cloud/linode/cloud.go b/cloud/linode/cloud.go index ed9b21ed..c2a8ae31 100644 --- a/cloud/linode/cloud.go +++ b/cloud/linode/cloud.go @@ -1,6 +1,7 @@ package linode import ( + "context" "fmt" "io" "net" @@ -32,9 +33,10 @@ var supportedLoadBalancerTypes = []string{ciliumLBType, nodeBalancerLBType} // We expect it to be initialized with flags external to this package, likely in // main.go var Options struct { - KubeconfigFlag *pflag.Flag - LinodeGoDebug bool - EnableRouteController bool + KubeconfigFlag *pflag.Flag + LinodeGoDebug bool + EnableRouteController bool + EnableTokenHealthChecker bool // Deprecated: use VPCNames instead VPCName string VPCNames string @@ -43,13 +45,15 @@ var Options struct { IpHolderSuffix string LinodeExternalNetwork *net.IPNet NodeBalancerTags []string + GlobalStopChannel chan<- struct{} } type linodeCloud struct { - client client.Client - instances cloudprovider.InstancesV2 - loadbalancers cloudprovider.LoadBalancer - routes cloudprovider.Routes + client client.Client + instances cloudprovider.InstancesV2 + loadbalancers cloudprovider.LoadBalancer + routes cloudprovider.Routes + linodeTokenHealthChecker *healthChecker } var instanceCache *instances @@ -91,6 +95,24 @@ func newCloud() (cloudprovider.Interface, error) { linodeClient.SetDebug(true) } + var healthChecker *healthChecker + + if Options.EnableTokenHealthChecker { + authenticated, err := client.CheckClientAuthenticated(context.TODO(), linodeClient) + if err != nil { + return nil, fmt.Errorf("linode client authenticated connection error: %w", err) + } + + if !authenticated { + return nil, fmt.Errorf("linode api token '%s' is invalid", accessTokenEnv) + } + + healthChecker, err = newHealthChecker(apiToken, timeout, time.Minute, Options.GlobalStopChannel) + if err != nil { + return nil, fmt.Errorf("unable to initialize healthchecker: %w", err) + } + } + if Options.VPCName != "" && Options.VPCNames != "" { return nil, fmt.Errorf("cannot have both vpc-name and vpc-names set") } @@ -126,10 +148,11 @@ func newCloud() (cloudprovider.Interface, error) { // create struct that satisfies cloudprovider.Interface lcloud := &linodeCloud{ - client: linodeClient, - instances: instanceCache, - loadbalancers: newLoadbalancers(linodeClient, region), - routes: routes, + client: linodeClient, + instances: instanceCache, + loadbalancers: newLoadbalancers(linodeClient, region), + routes: routes, + linodeTokenHealthChecker: healthChecker, } return lcloud, nil } @@ -140,6 +163,10 @@ func (c *linodeCloud) Initialize(clientBuilder cloudprovider.ControllerClientBui serviceInformer := sharedInformer.Core().V1().Services() nodeInformer := sharedInformer.Core().V1().Nodes() + if c.linodeTokenHealthChecker != nil { + go c.linodeTokenHealthChecker.Run(stopCh) + } + serviceController := newServiceController(c.loadbalancers.(*loadbalancers), serviceInformer) go serviceController.Run(stopCh) diff --git a/cloud/linode/fake_linode_test.go b/cloud/linode/fake_linode_test.go index aeb069d8..aea92ce1 100644 --- a/cloud/linode/fake_linode_test.go +++ b/cloud/linode/fake_linode_test.go @@ -25,6 +25,7 @@ type fakeAPI struct { nbn map[string]*linodego.NodeBalancerNode fw map[int]*linodego.Firewall // map of firewallID -> firewall fwd map[int]map[int]*linodego.FirewallDevice // map of firewallID -> firewallDeviceID:FirewallDevice + tkn string requests map[fakeRequest]struct{} mux *http.ServeMux @@ -674,6 +675,29 @@ func (f *fakeAPI) setupRoutes() { rr, _ := json.Marshal(resp) _, _ = w.Write(rr) }) + + f.mux.HandleFunc("GET /v4/profile", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+f.tkn { + errors := make([]linodego.APIErrorReason, 1) + errors[0] = linodego.APIErrorReason{Reason: "Invalid Token"} + resp := linodego.APIError{Errors: errors} + + w.WriteHeader(401) + rr, _ := json.Marshal(resp) + _, _ = w.Write(rr) + return + } + + resp := linodego.Profile{ + UID: 0, + Username: "foo", + Email: "fake@example.com", + } + + w.WriteHeader(200) + rr, _ := json.Marshal(resp) + _, _ = w.Write(rr) + }) } func (f *fakeAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/cloud/linode/health_check.go b/cloud/linode/health_check.go new file mode 100644 index 00000000..922a046f --- /dev/null +++ b/cloud/linode/health_check.go @@ -0,0 +1,63 @@ +package linode + +import ( + "context" + "time" + + "github.com/linode/linode-cloud-controller-manager/cloud/linode/client" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/klog/v2" +) + +type healthChecker struct { + period time.Duration + linodeClient client.Client + stopCh chan<- struct{} +} + +func newHealthChecker(apiToken string, timeout time.Duration, period time.Duration, stopCh chan<- struct{}) (*healthChecker, error) { + client, err := client.New(apiToken, timeout) + if err != nil { + return nil, err + } + + return &healthChecker{ + period: period, + linodeClient: client, + stopCh: stopCh, + }, nil +} + +func (r *healthChecker) Run(stopCh <-chan struct{}) { + ctx := wait.ContextForChannel(stopCh) + wait.Until(r.worker(ctx), r.period, stopCh) +} + +func (r *healthChecker) worker(ctx context.Context) func() { + return func() { + r.do(ctx) + } +} + +func (r *healthChecker) do(ctx context.Context) { + if r.stopCh == nil { + klog.Errorf("stop signal already fired. nothing to do") + return + } + + authenticated, err := client.CheckClientAuthenticated(ctx, r.linodeClient) + if err != nil { + klog.Warningf("unable to determine linode client authentication status: %s", err.Error()) + return + } + + if !authenticated { + klog.Error("detected invalid linode api token: stopping controllers") + + close(r.stopCh) + r.stopCh = nil + return + } + + klog.Info("linode api token is healthy") +} diff --git a/cloud/linode/health_check_test.go b/cloud/linode/health_check_test.go new file mode 100644 index 00000000..fa0a1519 --- /dev/null +++ b/cloud/linode/health_check_test.go @@ -0,0 +1,124 @@ +package linode + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/linode/linodego" +) + +func TestHealthCheck(t *testing.T) { + testCases := []struct { + name string + f func(*testing.T, *linodego.Client, *fakeAPI) + }{ + { + name: "Test succeeding calls to linode api stop signal is not fired", + f: testSucceedingCallsToLinodeAPIHappenStopSignalNotFired, + }, + { + name: "Test failing calls to linode api stop signal is fired", + f: testFailingCallsToLinodeAPIHappenStopSignalFired, + }, + } + + for _, tc := range testCases { + fake := newFake(t) + ts := httptest.NewServer(fake) + + linodeClient := linodego.NewClient(http.DefaultClient) + linodeClient.SetBaseURL(ts.URL) + + t.Run(tc.name, func(t *testing.T) { + defer ts.Close() + tc.f(t, &linodeClient, fake) + }) + } +} + +func testSucceedingCallsToLinodeAPIHappenStopSignalNotFired(t *testing.T, client *linodego.Client, api *fakeAPI) { + writableStopCh := make(chan struct{}) + readableStopCh := make(chan struct{}) + const validToken = "validtoken" + + api.tkn = validToken + client.SetToken(validToken) + + hc, err := newHealthChecker(validToken, 1*time.Second, 1*time.Second, writableStopCh) + if err != nil { + t.Fatalf("expected a nil error, got %v", err) + } + // inject modified linodego.Client + hc.linodeClient = client + + go hc.Run(readableStopCh) + + // wait for check to happen + time.Sleep(2 * time.Second) + + // stop healthChecker goroutine + close(readableStopCh) + + if !api.didRequestOccur(http.MethodGet, "/profile", "") { + t.Error("unexpected linode api calls") + t.Logf("expected: %v /profile", http.MethodGet) + t.Logf("actual: %v", api.requests) + } + + select { + case <-writableStopCh: + t.Error("healthChecker sent stop signal") + default: + } +} + +func testFailingCallsToLinodeAPIHappenStopSignalFired(t *testing.T, client *linodego.Client, api *fakeAPI) { + writableStopCh := make(chan struct{}) + readableStopCh := make(chan struct{}) + const validToken = "validtoken" + const invalidToken = "invalidtoken" + + api.tkn = validToken + client.SetToken(validToken) + + hc, err := newHealthChecker(validToken, 1*time.Second, 1*time.Second, writableStopCh) + if err != nil { + t.Fatalf("expected a nil error, got %v", err) + } + // inject modified linodego.Client + hc.linodeClient = client + + go hc.Run(readableStopCh) + + // wait for check to happen + time.Sleep(2 * time.Second) + + if !api.didRequestOccur(http.MethodGet, "/profile", "") { + t.Error("unexpected linode api calls") + t.Logf("expected: %v /profile", http.MethodGet) + t.Logf("actual: %v", api.requests) + } + + select { + case <-writableStopCh: + t.Error("healthChecker sent stop signal") + default: + } + + // invalidate token + api.tkn = invalidToken + + // wait for check to happen + time.Sleep(2 * time.Second) + + select { + case <-writableStopCh: + default: + t.Error("healthChecker did not send stop signal") + } + + // stop healthChecker goroutine + close(readableStopCh) +} diff --git a/main.go b/main.go index 593755c8..8c52cf0f 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,6 @@ import ( "github.com/linode/linode-cloud-controller-manager/cloud/linode" "github.com/linode/linode-cloud-controller-manager/sentry" "github.com/spf13/pflag" - "k8s.io/apimachinery/pkg/util/wait" cloudprovider "k8s.io/cloud-provider" "k8s.io/cloud-provider/app" "k8s.io/cloud-provider/app/config" @@ -76,11 +75,13 @@ func main() { } fss := utilflag.NamedFlagSets{} controllerAliases := names.CCMControllerAliases() - command := app.NewCloudControllerManagerCommand(ccmOptions, cloudInitializer, app.DefaultInitFuncConstructors, controllerAliases, fss, wait.NeverStop) + stopCh := make(chan struct{}) + command := app.NewCloudControllerManagerCommand(ccmOptions, cloudInitializer, app.DefaultInitFuncConstructors, controllerAliases, fss, stopCh) // Add Linode-specific flags command.Flags().BoolVar(&linode.Options.LinodeGoDebug, "linodego-debug", false, "enables debug output for the LinodeAPI wrapper") command.Flags().BoolVar(&linode.Options.EnableRouteController, "enable-route-controller", false, "enables route_controller for ccm") + command.Flags().BoolVar(&linode.Options.EnableTokenHealthChecker, "enable-token-health-checker", false, "enables linode api token health checker") command.Flags().StringVar(&linode.Options.VPCName, "vpc-name", "", "[deprecated: use vpc-names instead] vpc name whose routes will be managed by route-controller") command.Flags().StringVar(&linode.Options.VPCNames, "vpc-names", "", "comma separated vpc names whose routes will be managed by route-controller") command.Flags().StringVar(&linode.Options.LoadBalancerType, "load-balancer-type", "nodebalancer", "configures which type of load-balancing to use for LoadBalancer Services (options: nodebalancer, cilium-bgp)") @@ -130,6 +131,9 @@ func main() { linode.Options.LinodeExternalNetwork = network } + // Provide stop channel for linode authenticated client healthchecker + linode.Options.GlobalStopChannel = stopCh + pflag.CommandLine.SetNormalizeFunc(utilflag.WordSepNormalizeFunc) pflag.CommandLine.AddGoFlagSet(flag.CommandLine)