Skip to content

Commit 5c3b2ba

Browse files
committed
fix: unit tests work after cluster provider changes
Signed-off-by: Calum Murray <[email protected]>
1 parent b35cadd commit 5c3b2ba

File tree

4 files changed

+71
-32
lines changed

4 files changed

+71
-32
lines changed

pkg/http/http_test.go

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ func TestHealthCheck(t *testing.T) {
292292
})
293293
})
294294
// Health exposed even when require Authorization
295-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
295+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
296296
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
297297
if err != nil {
298298
t.Fatalf("Failed to get health check endpoint with OAuth: %v", err)
@@ -313,7 +313,7 @@ func TestWellKnownReverseProxy(t *testing.T) {
313313
".well-known/openid-configuration",
314314
}
315315
// With No Authorization URL configured
316-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
316+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
317317
for _, path := range cases {
318318
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
319319
t.Cleanup(func() { _ = resp.Body.Close() })
@@ -333,7 +333,12 @@ func TestWellKnownReverseProxy(t *testing.T) {
333333
_, _ = w.Write([]byte(`NOT A JSON PAYLOAD`))
334334
}))
335335
t.Cleanup(invalidPayloadServer.Close)
336-
invalidPayloadConfig := &config.StaticConfig{AuthorizationURL: invalidPayloadServer.URL, RequireOAuth: true, ValidateToken: true}
336+
invalidPayloadConfig := &config.StaticConfig{
337+
AuthorizationURL: invalidPayloadServer.URL,
338+
RequireOAuth: true,
339+
ValidateToken: true,
340+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
341+
}
337342
testCaseWithContext(t, &httpContext{StaticConfig: invalidPayloadConfig}, func(ctx *httpContext) {
338343
for _, path := range cases {
339344
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
@@ -358,7 +363,12 @@ func TestWellKnownReverseProxy(t *testing.T) {
358363
_, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`))
359364
}))
360365
t.Cleanup(testServer.Close)
361-
staticConfig := &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
366+
staticConfig := &config.StaticConfig{
367+
AuthorizationURL: testServer.URL,
368+
RequireOAuth: true,
369+
ValidateToken: true,
370+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
371+
}
362372
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) {
363373
for _, path := range cases {
364374
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
@@ -401,7 +411,12 @@ func TestWellKnownOverrides(t *testing.T) {
401411
}`))
402412
}))
403413
t.Cleanup(testServer.Close)
404-
baseConfig := config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
414+
baseConfig := config.StaticConfig{
415+
AuthorizationURL: testServer.URL,
416+
RequireOAuth: true,
417+
ValidateToken: true,
418+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
419+
}
405420
// With Dynamic Client Registration disabled
406421
disableDynamicRegistrationConfig := baseConfig
407422
disableDynamicRegistrationConfig.DisableDynamicClientRegistration = true
@@ -488,7 +503,7 @@ func TestMiddlewareLogging(t *testing.T) {
488503

489504
func TestAuthorizationUnauthorized(t *testing.T) {
490505
// Missing Authorization header
491-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
506+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
492507
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
493508
if err != nil {
494509
t.Fatalf("Failed to get protected endpoint: %v", err)
@@ -513,7 +528,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
513528
})
514529
})
515530
// Authorization header without Bearer prefix
516-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
531+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
517532
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
518533
if err != nil {
519534
t.Fatalf("Failed to create request: %v", err)
@@ -538,7 +553,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
538553
})
539554
})
540555
// Invalid Authorization header
541-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
556+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
542557
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
543558
if err != nil {
544559
t.Fatalf("Failed to create request: %v", err)
@@ -569,7 +584,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
569584
})
570585
})
571586
// Expired Authorization Bearer token
572-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
587+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
573588
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
574589
if err != nil {
575590
t.Fatalf("Failed to create request: %v", err)
@@ -600,7 +615,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
600615
})
601616
})
602617
// Invalid audience claim Bearer token
603-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true}}, func(ctx *httpContext) {
618+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
604619
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
605620
if err != nil {
606621
t.Fatalf("Failed to create request: %v", err)
@@ -633,7 +648,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
633648
// Failed OIDC validation
634649
oidcTestServer := NewOidcTestServer(t)
635650
t.Cleanup(oidcTestServer.Close)
636-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
651+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
637652
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
638653
if err != nil {
639654
t.Fatalf("Failed to create request: %v", err)
@@ -670,7 +685,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
670685
"aud": "mcp-server"
671686
}`
672687
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
673-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
688+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
674689
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
675690
if err != nil {
676691
t.Fatalf("Failed to create request: %v", err)
@@ -703,7 +718,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
703718
}
704719

705720
func TestAuthorizationRequireOAuthFalse(t *testing.T) {
706-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
721+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
707722
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
708723
if err != nil {
709724
t.Fatalf("Failed to get protected endpoint: %v", err)
@@ -728,7 +743,7 @@ func TestAuthorizationRawToken(t *testing.T) {
728743
{"mcp-server", true}, // Audience set, validation enabled
729744
}
730745
for _, c := range cases {
731-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken}}, func(ctx *httpContext) {
746+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
732747
tokenReviewed := false
733748
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
734749
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
@@ -777,7 +792,7 @@ func TestAuthorizationOidcToken(t *testing.T) {
777792
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
778793
cases := []bool{false, true}
779794
for _, validateToken := range cases {
780-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
795+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
781796
tokenReviewed := false
782797
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
783798
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
@@ -833,13 +848,14 @@ func TestAuthorizationOidcTokenExchange(t *testing.T) {
833848
cases := []bool{false, true}
834849
for _, validateToken := range cases {
835850
staticConfig := &config.StaticConfig{
836-
RequireOAuth: true,
837-
OAuthAudience: "mcp-server",
838-
ValidateToken: validateToken,
839-
StsClientId: "test-sts-client-id",
840-
StsClientSecret: "test-sts-client-secret",
841-
StsAudience: "backend-audience",
842-
StsScopes: []string{"backend-scope"},
851+
RequireOAuth: true,
852+
OAuthAudience: "mcp-server",
853+
ValidateToken: validateToken,
854+
StsClientId: "test-sts-client-id",
855+
StsClientSecret: "test-sts-client-secret",
856+
StsAudience: "backend-audience",
857+
StsScopes: []string{"backend-scope"},
858+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
843859
}
844860
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
845861
tokenReviewed := false

pkg/kubernetes/cluster.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,36 @@ func newKubeConfigClusterProvider(config *config.StaticConfig) (*kubeConfigClust
4242
return nil, err
4343
}
4444

45+
// Handle in-cluster mode
46+
if m.IsInCluster() {
47+
return &kubeConfigClusterProvider{
48+
defaultCluster: "default",
49+
managers: map[string]*Manager{"default": m},
50+
}, nil
51+
}
52+
4553
rawConfig, err := m.clientCmdConfig.RawConfig()
4654
if err != nil {
4755
return nil, err
4856
}
4957

50-
defaultContext := rawConfig.Contexts[rawConfig.CurrentContext]
58+
defaultContext, ok := rawConfig.Contexts[rawConfig.CurrentContext]
59+
if !ok || defaultContext == nil {
60+
return nil, fmt.Errorf("current context '%s' not found in kubeconfig", rawConfig.CurrentContext)
61+
}
5162

52-
allClusterManagers := make(map[string]*Manager)
63+
allClusterManagers := map[string]*Manager{
64+
defaultContext.Cluster: m, // we already initialized a manager for the default context, let's use it
65+
}
5366

5467
for _, context := range rawConfig.Contexts {
5568
if _, exists := rawConfig.Clusters[context.Cluster]; exists {
56-
allClusterManagers[context.Cluster] = nil // these will be lazy initialized as they are accessed later
69+
if _, alreadyExists := allClusterManagers[context.Cluster]; !alreadyExists {
70+
allClusterManagers[context.Cluster] = nil // these will be lazy initialized as they are accessed later
71+
}
5772
}
5873
}
5974

60-
// we already initialized a manager for the default context, let's use it
61-
allClusterManagers[defaultContext.Cluster] = m
62-
6375
return &kubeConfigClusterProvider{
6476
defaultCluster: defaultContext.Cluster,
6577
managers: allClusterManagers,

pkg/mcp/common_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func (c *mcpContext) withKubeConfig(rc *rest.Config) *clientcmdapi.Config {
219219
_ = clientcmd.WriteToFile(*fakeConfig, kubeConfig)
220220
_ = os.Setenv("KUBECONFIG", kubeConfig)
221221
if c.mcpServer != nil {
222-
if err := c.mcpServer.reloadKubernetesClient(); err != nil {
222+
if err := c.mcpServer.reloadKubernetesClusterProvider(); err != nil {
223223
panic(err)
224224
}
225225
}

pkg/mcp/mcp_tools_test.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ func TestUnrestricted(t *testing.T) {
3434
}
3535

3636
func TestReadOnly(t *testing.T) {
37-
readOnlyServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{ReadOnly: true} }
37+
readOnlyServer := func(c *mcpContext) {
38+
c.staticConfig = &config.StaticConfig{
39+
ReadOnly: true,
40+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
41+
}
42+
}
3843
testCaseWithContext(t, &mcpContext{before: readOnlyServer}, func(c *mcpContext) {
3944
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
4045
t.Run("ListTools returns tools", func(t *testing.T) {
@@ -56,7 +61,12 @@ func TestReadOnly(t *testing.T) {
5661
}
5762

5863
func TestDisableDestructive(t *testing.T) {
59-
disableDestructiveServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{DisableDestructive: true} }
64+
disableDestructiveServer := func(c *mcpContext) {
65+
c.staticConfig = &config.StaticConfig{
66+
DisableDestructive: true,
67+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
68+
}
69+
}
6070
testCaseWithContext(t, &mcpContext{before: disableDestructiveServer}, func(c *mcpContext) {
6171
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
6272
t.Run("ListTools returns tools", func(t *testing.T) {
@@ -101,7 +111,8 @@ func TestEnabledTools(t *testing.T) {
101111
func TestDisabledTools(t *testing.T) {
102112
testCaseWithContext(t, &mcpContext{
103113
staticConfig: &config.StaticConfig{
104-
DisabledTools: []string{"namespaces_list", "events_list"},
114+
DisabledTools: []string{"namespaces_list", "events_list"},
115+
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
105116
},
106117
}, func(c *mcpContext) {
107118
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})

0 commit comments

Comments
 (0)