diff --git a/internal/test/mcp.go b/internal/test/mcp.go index b82e3194..5fa0d0a4 100644 --- a/internal/test/mcp.go +++ b/internal/test/mcp.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" "golang.org/x/net/context" @@ -17,12 +18,12 @@ type McpClient struct { *client.Client } -func NewMcpClient(t *testing.T, mcpHttpServer http.Handler) *McpClient { +func NewMcpClient(t *testing.T, mcpHttpServer http.Handler, options ...transport.StreamableHTTPCOption) *McpClient { require.NotNil(t, mcpHttpServer, "McpHttpServer must be provided") var err error ret := &McpClient{ctx: t.Context()} ret.testServer = httptest.NewServer(mcpHttpServer) - ret.Client, err = client.NewStreamableHttpClient(ret.testServer.URL + "/mcp") + ret.Client, err = client.NewStreamableHttpClient(ret.testServer.URL+"/mcp", options...) require.NoError(t, err, "Expected no error creating MCP client") err = ret.Start(t.Context()) require.NoError(t, err, "Expected no error starting MCP client") diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index e9c49758..86f2e8d6 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -443,9 +443,9 @@ func (s *BaseMcpSuite) TearDownTest() { } } -func (s *BaseMcpSuite) InitMcpClient() { +func (s *BaseMcpSuite) InitMcpClient(options ...transport.StreamableHTTPCOption) { var err error s.mcpServer, err = NewServer(Configuration{StaticConfig: s.Cfg}) s.Require().NoError(err, "Expected no error creating MCP server") - s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP(nil)) + s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP(nil), options...) } diff --git a/pkg/mcp/mcp_test.go b/pkg/mcp/mcp_test.go index 7be9a423..484d8b59 100644 --- a/pkg/mcp/mcp_test.go +++ b/pkg/mcp/mcp_test.go @@ -10,8 +10,9 @@ import ( "time" "github.com/containers/kubernetes-mcp-server/internal/test" - "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/suite" ) func TestWatchKubeConfig(t *testing.T) { @@ -48,16 +49,19 @@ func TestWatchKubeConfig(t *testing.T) { }) } -func TestSseHeaders(t *testing.T) { - mockServer := test.NewMockServer() - defer mockServer.Close() - before := func(c *mcpContext) { - c.withKubeConfig(mockServer.Config()) - c.clientOptions = append(c.clientOptions, client.WithHeaders(map[string]string{"kubernetes-authorization": "Bearer a-token-from-mcp-client"})) - } - pathHeaders := make(map[string]http.Header, 0) - mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - pathHeaders[req.URL.Path] = req.Header.Clone() +type McpHeadersSuite struct { + BaseMcpSuite + mockServer *test.MockServer + pathHeaders map[string]http.Header +} + +func (s *McpHeadersSuite) SetupTest() { + s.BaseMcpSuite.SetupTest() + s.mockServer = test.NewMockServer() + s.Cfg.KubeConfig = s.mockServer.KubeconfigFile(s.T()) + s.pathHeaders = make(map[string]http.Header) + s.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + s.pathHeaders[req.URL.Path] = req.Header.Clone() // Request Performed by DiscoveryClient to Kube API (Get API Groups legacy -core-) if req.URL.Path == "/api" { w.Header().Set("Content-Type", "application/json") @@ -90,38 +94,42 @@ func TestSseHeaders(t *testing.T) { } w.WriteHeader(404) })) - testCaseWithContext(t, &mcpContext{before: before}, func(c *mcpContext) { - _, _ = c.callTool("pods_list", map[string]interface{}{}) - t.Run("DiscoveryClient propagates headers to Kube API", func(t *testing.T) { - if len(pathHeaders) == 0 { - t.Fatalf("No requests were made to Kube API") - } - if pathHeaders["/api"] == nil || pathHeaders["/api"].Get("Authorization") != "Bearer a-token-from-mcp-client" { - t.Fatalf("Overridden header Authorization not found in request to /api") - } - if pathHeaders["/apis"] == nil || pathHeaders["/apis"].Get("Authorization") != "Bearer a-token-from-mcp-client" { - t.Fatalf("Overridden header Authorization not found in request to /apis") - } - if pathHeaders["/api/v1"] == nil || pathHeaders["/api/v1"].Get("Authorization") != "Bearer a-token-from-mcp-client" { - t.Fatalf("Overridden header Authorization not found in request to /api/v1") - } +} + +func (s *McpHeadersSuite) TearDownTest() { + s.BaseMcpSuite.TearDownTest() + if s.mockServer != nil { + s.mockServer.Close() + } +} + +func (s *McpHeadersSuite) TestAuthorizationHeaderPropagation() { + cases := []string{"kubernetes-authorization", "Authorization"} + for _, header := range cases { + s.InitMcpClient(transport.WithHTTPHeaders(map[string]string{header: "Bearer a-token-from-mcp-client"})) + _, _ = s.CallTool("pods_list", map[string]interface{}{}) + s.Require().Greater(len(s.pathHeaders), 0, "No requests were made to Kube API") + s.Run("DiscoveryClient propagates "+header+" header to Kube API", func() { + s.Require().NotNil(s.pathHeaders["/api"], "No requests were made to /api") + s.Equal("Bearer a-token-from-mcp-client", s.pathHeaders["/api"].Get("Authorization"), "Overridden header Authorization not found in request to /api") + s.Require().NotNil(s.pathHeaders["/apis"], "No requests were made to /apis") + s.Equal("Bearer a-token-from-mcp-client", s.pathHeaders["/apis"].Get("Authorization"), "Overridden header Authorization not found in request to /apis") + s.Require().NotNil(s.pathHeaders["/api/v1"], "No requests were made to /api/v1") + s.Equal("Bearer a-token-from-mcp-client", s.pathHeaders["/api/v1"].Get("Authorization"), "Overridden header Authorization not found in request to /api/v1") }) - t.Run("DynamicClient propagates headers to Kube API", func(t *testing.T) { - if len(pathHeaders) == 0 { - t.Fatalf("No requests were made to Kube API") - } - if pathHeaders["/api/v1/namespaces/default/pods"] == nil || pathHeaders["/api/v1/namespaces/default/pods"].Get("Authorization") != "Bearer a-token-from-mcp-client" { - t.Fatalf("Overridden header Authorization not found in request to /api/v1/namespaces/default/pods") - } + s.Run("DynamicClient propagates "+header+" header to Kube API", func() { + s.Require().NotNil(s.pathHeaders["/api/v1/namespaces/default/pods"], "No requests were made to /api/v1/namespaces/default/pods") + s.Equal("Bearer a-token-from-mcp-client", s.pathHeaders["/api/v1/namespaces/default/pods"].Get("Authorization"), "Overridden header Authorization not found in request to /api/v1/namespaces/default/pods") }) - _, _ = c.callTool("pods_delete", map[string]interface{}{"name": "a-pod-to-delete"}) - t.Run("kubernetes.Interface propagates headers to Kube API", func(t *testing.T) { - if len(pathHeaders) == 0 { - t.Fatalf("No requests were made to Kube API") - } - if pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"] == nil || pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"].Get("Authorization") != "Bearer a-token-from-mcp-client" { - t.Fatalf("Overridden header Authorization not found in request to /api/v1/namespaces/default/pods/a-pod-to-delete") - } + _, _ = s.CallTool("pods_delete", map[string]interface{}{"name": "a-pod-to-delete"}) + s.Run("kubernetes.Interface propagates "+header+" header to Kube API", func() { + s.Require().NotNil(s.pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"], "No requests were made to /api/v1/namespaces/default/pods/a-pod-to-delete") + s.Equal("Bearer a-token-from-mcp-client", s.pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"].Get("Authorization"), "Overridden header Authorization not found in request to /api/v1/namespaces/default/pods/a-pod-to-delete") }) - }) + + } +} + +func TestMcpHeaders(t *testing.T) { + suite.Run(t, new(McpHeadersSuite)) }