diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 5a4f1d51..cf7676a3 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -190,7 +190,7 @@ func toolCallLoggingMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFu klog.V(5).Infof("mcp tool call: %s(%v)", ctr.Params.Name, ctr.Params.Arguments) if ctr.Header != nil { buffer := bytes.NewBuffer(make([]byte, 0)) - if err := ctr.Header.Write(buffer); err == nil { + if err := ctr.Header.WriteSubset(buffer, map[string]bool{"Authorization": true, "authorization": true}); err == nil { klog.V(7).Infof("mcp tool call headers: %s", buffer) } } diff --git a/pkg/mcp/mcp_tools_test.go b/pkg/mcp/mcp_tools_test.go index 2280c493..4d12d306 100644 --- a/pkg/mcp/mcp_tools_test.go +++ b/pkg/mcp/mcp_tools_test.go @@ -1,6 +1,7 @@ package mcp import ( + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "k8s.io/utils/ptr" "regexp" @@ -140,16 +141,39 @@ func TestToolCallLogging(t *testing.T) { } }) }) - testCaseWithContext(t, &mcpContext{logLevel: 7}, func(c *mcpContext) { + before := func(c *mcpContext) { + c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{ + "Accept-Encoding": "gzip", + "Authorization": "Bearer should-not-be-logged", + "authorization": "Bearer should-not-be-logged", + "a-loggable-header": "should-be-logged", + })) + } + testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) { _, _ = c.callTool("configuration_view", map[string]interface{}{ "minified": false, }) t.Run("Logs tool call headers", func(t *testing.T) { - expectedLog := "mcp tool call headers: Accept-Encoding: gzip" + expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged" if !strings.Contains(c.logBuffer.String(), expectedLog) { t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String()) } }) - + sensitiveHeaders := []string{ + "Authorization", + // TODO: Add more sensitive headers as needed + } + t.Run("Does not log sensitive headers", func(t *testing.T) { + for _, header := range sensitiveHeaders { + if strings.Contains(c.logBuffer.String(), header) { + t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String()) + } + } + }) + t.Run("Does not log sensitive header values", func(t *testing.T) { + if strings.Contains(c.logBuffer.String(), "should-not-be-logged") { + t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String()) + } + }) }) }