diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index 8546c4aa..b91df691 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -1,23 +1,16 @@ package mcp import ( - "bytes" "context" "encoding/json" - "flag" "fmt" - "net/http/httptest" "os" "path/filepath" "runtime" - "strconv" "testing" "time" - "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" "github.com/pkg/errors" "github.com/spf13/afero" "github.com/stretchr/testify/suite" @@ -30,11 +23,7 @@ import ( "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" - clientcmdapi "k8s.io/client-go/tools/clientcmd/api" toolswatch "k8s.io/client-go/tools/watch" - "k8s.io/klog/v2" - "k8s.io/klog/v2/textlogger" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/tools/setup-envtest/env" @@ -45,7 +34,6 @@ import ( "github.com/containers/kubernetes-mcp-server/internal/test" "github.com/containers/kubernetes-mcp-server/pkg/config" - "github.com/containers/kubernetes-mcp-server/pkg/output" ) // envTest has an expensive setup, so we only want to do it once per entire test run. @@ -103,133 +91,6 @@ func TestMain(m *testing.M) { os.Exit(code) } -type mcpContext struct { - toolsets []string - listOutput output.Output - logLevel int - - staticConfig *config.StaticConfig - clientOptions []transport.ClientOption - before func(*mcpContext) - after func(*mcpContext) - ctx context.Context - tempDir string - cancel context.CancelFunc - mcpServer *Server - mcpHttpServer *httptest.Server - mcpClient *client.Client - klogState klog.State - logBuffer bytes.Buffer -} - -func (c *mcpContext) beforeEach(t *testing.T) { - var err error - c.ctx, c.cancel = context.WithCancel(t.Context()) - c.tempDir = t.TempDir() - c.withKubeConfig(nil) - if c.staticConfig == nil { - c.staticConfig = config.Default() - // Default to use YAML output for lists (previously the default) - c.staticConfig.ListOutput = "yaml" - } - if c.toolsets != nil { - c.staticConfig.Toolsets = c.toolsets - - } - if c.listOutput != nil { - c.staticConfig.ListOutput = c.listOutput.GetName() - } - if c.before != nil { - c.before(c) - } - // Set up logging - c.klogState = klog.CaptureState() - flags := flag.NewFlagSet("test", flag.ContinueOnError) - klog.InitFlags(flags) - _ = flags.Set("v", strconv.Itoa(c.logLevel)) - klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(c.logLevel), textlogger.Output(&c.logBuffer)))) - // MCP Server - if c.mcpServer, err = NewServer(Configuration{StaticConfig: c.staticConfig}); err != nil { - t.Fatal(err) - return - } - c.mcpHttpServer = server.NewTestServer(c.mcpServer.server, server.WithSSEContextFunc(contextFunc)) - if c.mcpClient, err = client.NewSSEMCPClient(c.mcpHttpServer.URL+"/sse", c.clientOptions...); err != nil { - t.Fatal(err) - return - } - // MCP Client - if err = c.mcpClient.Start(c.ctx); err != nil { - t.Fatal(err) - return - } - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.33.7"} - _, err = c.mcpClient.Initialize(c.ctx, initRequest) - if err != nil { - t.Fatal(err) - return - } -} - -func (c *mcpContext) afterEach() { - if c.after != nil { - c.after(c) - } - c.cancel() - c.mcpServer.Close() - _ = c.mcpClient.Close() - c.mcpHttpServer.Close() - c.klogState.Restore() -} - -func testCaseWithContext(t *testing.T, mcpCtx *mcpContext, test func(c *mcpContext)) { - mcpCtx.beforeEach(t) - defer mcpCtx.afterEach() - test(mcpCtx) -} - -// withKubeConfig sets up a fake kubeconfig in the temp directory based on the provided rest.Config -func (c *mcpContext) withKubeConfig(rc *rest.Config) *clientcmdapi.Config { - fakeConfig := clientcmdapi.NewConfig() - fakeConfig.Clusters["fake"] = clientcmdapi.NewCluster() - fakeConfig.Clusters["fake"].Server = "https://127.0.0.1:6443" - fakeConfig.Clusters["additional-cluster"] = clientcmdapi.NewCluster() - fakeConfig.AuthInfos["fake"] = clientcmdapi.NewAuthInfo() - fakeConfig.AuthInfos["additional-auth"] = clientcmdapi.NewAuthInfo() - if rc != nil { - fakeConfig.Clusters["fake"].Server = rc.Host - fakeConfig.Clusters["fake"].CertificateAuthorityData = rc.CAData - fakeConfig.AuthInfos["fake"].ClientKeyData = rc.KeyData - fakeConfig.AuthInfos["fake"].ClientCertificateData = rc.CertData - } - fakeConfig.Contexts["fake-context"] = clientcmdapi.NewContext() - fakeConfig.Contexts["fake-context"].Cluster = "fake" - fakeConfig.Contexts["fake-context"].AuthInfo = "fake" - fakeConfig.Contexts["additional-context"] = clientcmdapi.NewContext() - fakeConfig.Contexts["additional-context"].Cluster = "additional-cluster" - fakeConfig.Contexts["additional-context"].AuthInfo = "additional-auth" - fakeConfig.CurrentContext = "fake-context" - kubeConfig := filepath.Join(c.tempDir, "config") - _ = clientcmd.WriteToFile(*fakeConfig, kubeConfig) - _ = os.Setenv("KUBECONFIG", kubeConfig) - if c.mcpServer != nil { - if err := c.mcpServer.reloadKubernetesClusterProvider(); err != nil { - panic(err) - } - } - return fakeConfig -} - -// callTool helper function to call a tool by name with arguments -func (c *mcpContext) callTool(name string, args map[string]interface{}) (*mcp.CallToolResult, error) { - callToolRequest := mcp.CallToolRequest{} - callToolRequest.Params.Name = name - callToolRequest.Params.Arguments = args - return c.mcpClient.CallTool(c.ctx, callToolRequest) -} - func restoreAuth(ctx context.Context) { kubernetesAdmin := kubernetes.NewForConfigOrDie(envTest.Config) // Authorization diff --git a/pkg/mcp/mcp_middleware_test.go b/pkg/mcp/mcp_middleware_test.go index 987bfe4f..ce88e7b4 100644 --- a/pkg/mcp/mcp_middleware_test.go +++ b/pkg/mcp/mcp_middleware_test.go @@ -1,68 +1,87 @@ package mcp import ( + "bytes" + "flag" "regexp" - "strings" + "strconv" "testing" "github.com/mark3labs/mcp-go/client/transport" + "github.com/stretchr/testify/suite" + "k8s.io/klog/v2" + "k8s.io/klog/v2/textlogger" ) -func TestToolCallLogging(t *testing.T) { - testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) { - _, _ = c.callTool("configuration_view", map[string]interface{}{ - "minified": false, - }) - t.Run("Logs tool name", func(t *testing.T) { - expectedLog := "mcp tool call: configuration_view(" - if !strings.Contains(c.logBuffer.String(), expectedLog) { - t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String()) - } - }) - t.Run("Logs tool call arguments", func(t *testing.T) { - expected := `"mcp tool call: configuration_view\((.+)\)"` - m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String()) - if len(m) != 2 { - t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String()) - } - if m[1] != "map[minified:false]" { - t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1]) - } - }) +type McpLoggingSuite struct { + BaseMcpSuite + klogState klog.State + logBuffer bytes.Buffer +} + +func (s *McpLoggingSuite) SetupTest() { + s.BaseMcpSuite.SetupTest() + s.klogState = klog.CaptureState() +} + +func (s *McpLoggingSuite) TearDownTest() { + s.BaseMcpSuite.TearDownTest() + s.klogState.Restore() +} + +func (s *McpLoggingSuite) SetLogLevel(level int) { + flags := flag.NewFlagSet("test", flag.ContinueOnError) + klog.InitFlags(flags) + _ = flags.Set("v", strconv.Itoa(level)) + klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(level), textlogger.Output(&s.logBuffer)))) +} + +func (s *McpLoggingSuite) TestLogsToolCall() { + s.SetLogLevel(5) + s.InitMcpClient() + _, err := s.CallTool("configuration_view", map[string]interface{}{"minified": false}) + s.Require().NoError(err, "call to tool configuration_view failed") + + s.Run("Logs tool name", func() { + s.Contains(s.logBuffer.String(), "mcp tool call: configuration_view(") }) - before := func(c *mcpContext) { - c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{ - "Accept-Encoding": "gzip", - "Authorization": "Bearer should-not-be-logged", - "authorization": "Bearer should-not-be-logged", - "a-loggable-header": "should-be-logged", - })) + s.Run("Logs tool call arguments", func() { + expected := `"mcp tool call: configuration_view\((.+)\)"` + m := regexp.MustCompile(expected).FindStringSubmatch(s.logBuffer.String()) + s.Len(m, 2, "Expected log entry to contain arguments") + s.Equal("map[minified:false]", m[1], "Expected log arguments to be 'map[minified:false]'") + }) +} + +func (s *McpLoggingSuite) TestLogsToolCallHeaders() { + s.SetLogLevel(7) + s.InitMcpClient(transport.WithHTTPHeaders(map[string]string{ + "Accept-Encoding": "gzip", + "Authorization": "Bearer should-not-be-logged", + "authorization": "Bearer should-not-be-logged", + "a-loggable-header": "should-be-logged", + })) + _, err := s.CallTool("configuration_view", map[string]interface{}{"minified": false}) + s.Require().NoError(err, "call to tool configuration_view failed") + + s.Run("Logs tool call headers", func() { + expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged" + s.Contains(s.logBuffer.String(), expectedLog, "Expected log to contain loggable header") + }) + sensitiveHeaders := []string{ + "Authorization:", + // TODO: Add more sensitive headers as needed } - testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) { - _, _ = c.callTool("configuration_view", map[string]interface{}{ - "minified": false, - }) - t.Run("Logs tool call headers", func(t *testing.T) { - expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged" - if !strings.Contains(c.logBuffer.String(), expectedLog) { - t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String()) - } - }) - sensitiveHeaders := []string{ - "Authorization:", - // TODO: Add more sensitive headers as needed + s.Run("Does not log sensitive headers", func() { + for _, header := range sensitiveHeaders { + s.NotContains(s.logBuffer.String(), header, "Log should not contain sensitive header") } - t.Run("Does not log sensitive headers", func(t *testing.T) { - for _, header := range sensitiveHeaders { - if strings.Contains(c.logBuffer.String(), header) { - t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String()) - } - } - }) - t.Run("Does not log sensitive header values", func(t *testing.T) { - if strings.Contains(c.logBuffer.String(), "should-not-be-logged") { - t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String()) - } - }) }) + s.Run("Does not log sensitive header values", func() { + s.NotContains(s.logBuffer.String(), "should-not-be-logged", "Log should not contain sensitive header value") + }) +} + +func TestMcpLogging(t *testing.T) { + suite.Run(t, new(McpLoggingSuite)) }