Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions internal/test/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@ func KubeConfigFake() *clientcmdapi.Config {
fakeConfig := clientcmdapi.NewConfig()
fakeConfig.Clusters["fake"] = clientcmdapi.NewCluster()
fakeConfig.Clusters["fake"].Server = "https://127.0.0.1:6443"
fakeConfig.Clusters["additional-cluster"] = clientcmdapi.NewCluster()
fakeConfig.AuthInfos["fake"] = clientcmdapi.NewAuthInfo()
fakeConfig.AuthInfos["additional-auth"] = clientcmdapi.NewAuthInfo()
fakeConfig.Contexts["fake-context"] = clientcmdapi.NewContext()
fakeConfig.Contexts["fake-context"].Cluster = "fake"
fakeConfig.Contexts["fake-context"].AuthInfo = "fake"
fakeConfig.Contexts["additional-context"] = clientcmdapi.NewContext()
fakeConfig.Contexts["additional-context"].Cluster = "additional-cluster"
fakeConfig.Contexts["additional-context"].AuthInfo = "additional-auth"
fakeConfig.CurrentContext = "fake-context"
return fakeConfig
}
10 changes: 7 additions & 3 deletions internal/test/mock_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,14 @@ func (m *MockServer) Kubeconfig() *api.Config {
}

func (m *MockServer) KubeconfigFile(t *testing.T) string {
kubeconfig := filepath.Join(t.TempDir(), "config")
err := clientcmd.WriteToFile(*m.Kubeconfig(), kubeconfig)
return KubeconfigFile(t, m.Kubeconfig())
}

func KubeconfigFile(t *testing.T, kubeconfig *api.Config) string {
kubeconfigFile := filepath.Join(t.TempDir(), "config")
err := clientcmd.WriteToFile(*kubeconfig, kubeconfigFile)
require.NoError(t, err, "Expected no error writing kubeconfig file")
return kubeconfig
return kubeconfigFile
}

func WriteObject(w http.ResponseWriter, obj runtime.Object) {
Expand Down
25 changes: 23 additions & 2 deletions pkg/api/toolsets.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,29 @@ import (
)

type ServerTool struct {
Tool Tool
Handler ToolHandlerFunc
Tool Tool
Handler ToolHandlerFunc
ClusterAware *bool
TargetListProvider *bool
}

// IsClusterAware indicates whether the tool can accept a "cluster" or "context" parameter
// to operate on a specific Kubernetes cluster context.
// Defaults to true if not explicitly set
func (s *ServerTool) IsClusterAware() bool {
if s.ClusterAware != nil {
return *s.ClusterAware
}
return true
}

// IsTargetListProvider indicates whether the tool is used to provide a list of targets (clusters/contexts)
// Defaults to false if not explicitly set
func (s *ServerTool) IsTargetListProvider() bool {
if s.TargetListProvider != nil {
return *s.TargetListProvider
}
return false
}

type Toolset interface {
Expand Down
47 changes: 47 additions & 0 deletions pkg/api/toolsets_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package api

import (
"testing"

"github.com/stretchr/testify/suite"
"k8s.io/utils/ptr"
)

type ToolsetsSuite struct {
suite.Suite
}

func (s *ToolsetsSuite) TestServerTool() {
s.Run("IsClusterAware", func() {
s.Run("defaults to true", func() {
tool := &ServerTool{}
s.True(tool.IsClusterAware(), "Expected IsClusterAware to be true by default")
})
s.Run("can be set to false", func() {
tool := &ServerTool{ClusterAware: ptr.To(false)}
s.False(tool.IsClusterAware(), "Expected IsClusterAware to be false when set to false")
})
s.Run("can be set to true", func() {
tool := &ServerTool{ClusterAware: ptr.To(true)}
s.True(tool.IsClusterAware(), "Expected IsClusterAware to be true when set to true")
})
})
s.Run("IsTargetListProvider", func() {
s.Run("defaults to false", func() {
tool := &ServerTool{}
s.False(tool.IsTargetListProvider(), "Expected IsTargetListProvider to be false by default")
})
s.Run("can be set to false", func() {
tool := &ServerTool{TargetListProvider: ptr.To(false)}
s.False(tool.IsTargetListProvider(), "Expected IsTargetListProvider to be false when set to false")
})
s.Run("can be set to true", func() {
tool := &ServerTool{TargetListProvider: ptr.To(true)}
s.True(tool.IsTargetListProvider(), "Expected IsTargetListProvider to be true when set to true")
})
})
}

func TestToolsets(t *testing.T) {
suite.Run(t, new(ToolsetsSuite))
}
11 changes: 11 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ import (
"github.com/BurntSushi/toml"
)

const (
ClusterProviderKubeConfig = "kubeconfig"
ClusterProviderInCluster = "in-cluster"
)

// StaticConfig is the configuration for the server.
// It allows to configure server specific settings and tools to be enabled or disabled.
type StaticConfig struct {
Expand Down Expand Up @@ -49,6 +54,12 @@ type StaticConfig struct {
StsScopes []string `toml:"sts_scopes,omitempty"`
CertificateAuthority string `toml:"certificate_authority,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
// ClusterProviderStrategy is how the server finds clusters.
// If set to "kubeconfig", the clusters will be loaded from those in the kubeconfig.
// If set to "in-cluster", the server will use the in cluster config
ClusterProviderStrategy string `toml:"cluster_provider_strategy,omitempty"`
// ClusterContexts is which context should be used for each cluster
ClusterContexts map[string]string `toml:"cluster_contexts"`
}

func Default() *StaticConfig {
Expand Down
53 changes: 49 additions & 4 deletions pkg/http/authorization.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package http

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

Expand All @@ -20,7 +23,44 @@ import (

type KubernetesApiTokenVerifier interface {
// KubernetesApiVerifyToken TODO: clarify proper implementation
KubernetesApiVerifyToken(ctx context.Context, token, audience string) (*authenticationapiv1.UserInfo, []string, error)
KubernetesApiVerifyToken(ctx context.Context, token, audience, cluster string) (*authenticationapiv1.UserInfo, []string, error)
// GetTargetParameterName returns the parameter name used for target identification in MCP requests
GetTargetParameterName() string
}

// extractTargetFromRequest extracts cluster parameter from MCP request body
func extractTargetFromRequest(r *http.Request, targetName string) (string, error) {
if r.Body == nil {
return "", nil
}

// Read the body
body, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}

// Restore the body for downstream handlers
r.Body = io.NopCloser(bytes.NewBuffer(body))

// Parse the MCP request
var mcpRequest struct {
Params struct {
Arguments map[string]interface{} `json:"arguments"`
} `json:"params"`
}

if err := json.Unmarshal(body, &mcpRequest); err != nil {
// If we can't parse the request, just return empty cluster (will use default)
return "", nil
}

// Extract target parameter
if cluster, ok := mcpRequest.Params.Arguments[targetName].(string); ok {
return cluster, nil
}

return "", nil
}

// AuthorizationMiddleware validates the OAuth flow for protected resources.
Expand Down Expand Up @@ -128,7 +168,12 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
}
// Kubernetes API Server TokenReview validation
if err == nil && staticConfig.ValidateToken {
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, verifier)
targetParameterName := verifier.GetTargetParameterName()
cluster, clusterErr := extractTargetFromRequest(r, targetParameterName)
if clusterErr != nil {
klog.V(2).Infof("Failed to extract cluster from request, using default: %v", clusterErr)
}
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, cluster, verifier)
}
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
Expand Down Expand Up @@ -198,9 +243,9 @@ func (c *JWTClaims) ValidateWithProvider(ctx context.Context, audience string, p
return nil
}

func (c *JWTClaims) ValidateWithKubernetesApi(ctx context.Context, audience string, verifier KubernetesApiTokenVerifier) error {
func (c *JWTClaims) ValidateWithKubernetesApi(ctx context.Context, audience, cluster string, verifier KubernetesApiTokenVerifier) error {
if verifier != nil {
_, _, err := verifier.KubernetesApiVerifyToken(ctx, c.Token, audience)
_, _, err := verifier.KubernetesApiVerifyToken(ctx, c.Token, audience, cluster)
if err != nil {
return fmt.Errorf("kubernetes API token validation error: %v", err)
}
Expand Down
60 changes: 38 additions & 22 deletions pkg/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func TestHealthCheck(t *testing.T) {
})
})
// Health exposed even when require Authorization
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get health check endpoint with OAuth: %v", err)
Expand All @@ -313,7 +313,7 @@ func TestWellKnownReverseProxy(t *testing.T) {
".well-known/openid-configuration",
}
// With No Authorization URL configured
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
for _, path := range cases {
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
t.Cleanup(func() { _ = resp.Body.Close() })
Expand All @@ -333,7 +333,12 @@ func TestWellKnownReverseProxy(t *testing.T) {
_, _ = w.Write([]byte(`NOT A JSON PAYLOAD`))
}))
t.Cleanup(invalidPayloadServer.Close)
invalidPayloadConfig := &config.StaticConfig{AuthorizationURL: invalidPayloadServer.URL, RequireOAuth: true, ValidateToken: true}
invalidPayloadConfig := &config.StaticConfig{
AuthorizationURL: invalidPayloadServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: invalidPayloadConfig}, func(ctx *httpContext) {
for _, path := range cases {
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
Expand All @@ -358,7 +363,12 @@ func TestWellKnownReverseProxy(t *testing.T) {
_, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`))
}))
t.Cleanup(testServer.Close)
staticConfig := &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
staticConfig := &config.StaticConfig{
AuthorizationURL: testServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) {
for _, path := range cases {
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
Expand Down Expand Up @@ -401,7 +411,12 @@ func TestWellKnownOverrides(t *testing.T) {
}`))
}))
t.Cleanup(testServer.Close)
baseConfig := config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
baseConfig := config.StaticConfig{
AuthorizationURL: testServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
// With Dynamic Client Registration disabled
disableDynamicRegistrationConfig := baseConfig
disableDynamicRegistrationConfig.DisableDynamicClientRegistration = true
Expand Down Expand Up @@ -488,7 +503,7 @@ func TestMiddlewareLogging(t *testing.T) {

func TestAuthorizationUnauthorized(t *testing.T) {
// Missing Authorization header
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
Expand All @@ -513,7 +528,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Authorization header without Bearer prefix
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand All @@ -538,7 +553,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Invalid Authorization header
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -569,7 +584,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Expired Authorization Bearer token
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -600,7 +615,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Invalid audience claim Bearer token
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -633,7 +648,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
// Failed OIDC validation
oidcTestServer := NewOidcTestServer(t)
t.Cleanup(oidcTestServer.Close)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -670,7 +685,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
Expand Down Expand Up @@ -703,7 +718,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
}

func TestAuthorizationRequireOAuthFalse(t *testing.T) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
Expand All @@ -728,7 +743,7 @@ func TestAuthorizationRawToken(t *testing.T) {
{"mcp-server", true}, // Audience set, validation enabled
}
for _, c := range cases {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken}}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}}, func(ctx *httpContext) {
tokenReviewed := false
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
Expand Down Expand Up @@ -777,7 +792,7 @@ func TestAuthorizationOidcToken(t *testing.T) {
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
cases := []bool{false, true}
for _, validateToken := range cases {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken, ClusterProviderStrategy: config.ClusterProviderKubeConfig}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
tokenReviewed := false
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
Expand Down Expand Up @@ -833,13 +848,14 @@ func TestAuthorizationOidcTokenExchange(t *testing.T) {
cases := []bool{false, true}
for _, validateToken := range cases {
staticConfig := &config.StaticConfig{
RequireOAuth: true,
OAuthAudience: "mcp-server",
ValidateToken: validateToken,
StsClientId: "test-sts-client-id",
StsClientSecret: "test-sts-client-secret",
StsAudience: "backend-audience",
StsScopes: []string{"backend-scope"},
RequireOAuth: true,
OAuthAudience: "mcp-server",
ValidateToken: validateToken,
StsClientId: "test-sts-client-id",
StsClientSecret: "test-sts-client-secret",
StsAudience: "backend-audience",
StsScopes: []string{"backend-scope"},
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
tokenReviewed := false
Expand Down
Loading
Loading