Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions pkg/mcp/mcp_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package mcp

import (
"regexp"
"strings"
"testing"

"github.com/mark3labs/mcp-go/client/transport"
)

func TestToolCallLogging(t *testing.T) {
testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) {
_, _ = c.callTool("configuration_view", map[string]interface{}{
"minified": false,
})
t.Run("Logs tool name", func(t *testing.T) {
expectedLog := "mcp tool call: configuration_view("
if !strings.Contains(c.logBuffer.String(), expectedLog) {
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
}
})
t.Run("Logs tool call arguments", func(t *testing.T) {
expected := `"mcp tool call: configuration_view\((.+)\)"`
m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String())
if len(m) != 2 {
t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String())
}
if m[1] != "map[minified:false]" {
t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1])
}
})
})
before := func(c *mcpContext) {
c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{
"Accept-Encoding": "gzip",
"Authorization": "Bearer should-not-be-logged",
"authorization": "Bearer should-not-be-logged",
"a-loggable-header": "should-be-logged",
}))
}
testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) {
_, _ = c.callTool("configuration_view", map[string]interface{}{
"minified": false,
})
t.Run("Logs tool call headers", func(t *testing.T) {
expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged"
if !strings.Contains(c.logBuffer.String(), expectedLog) {
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
}
})
sensitiveHeaders := []string{
"Authorization:",
// TODO: Add more sensitive headers as needed
}
t.Run("Does not log sensitive headers", func(t *testing.T) {
for _, header := range sensitiveHeaders {
if strings.Contains(c.logBuffer.String(), header) {
t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String())
}
}
})
t.Run("Does not log sensitive header values", func(t *testing.T) {
if strings.Contains(c.logBuffer.String(), "should-not-be-logged") {
t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String())
}
})
})
}
254 changes: 102 additions & 152 deletions pkg/mcp/mcp_tools_test.go
Original file line number Diff line number Diff line change
@@ -1,180 +1,130 @@
package mcp

import (
"regexp"
"strings"
"testing"

"github.com/mark3labs/mcp-go/client/transport"
"github.com/BurntSushi/toml"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/suite"
"k8s.io/utils/ptr"

"github.com/containers/kubernetes-mcp-server/internal/test"
"github.com/containers/kubernetes-mcp-server/pkg/config"
)

func TestUnrestricted(t *testing.T) {
testCase(t, func(c *mcpContext) {
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
t.Run("ListTools returns tools", func(t *testing.T) {
if err != nil {
t.Fatalf("call ListTools failed %v", err)
}
})
t.Run("Destructive tools ARE NOT read only", func(t *testing.T) {
for _, tool := range tools.Tools {
readOnly := ptr.Deref(tool.Annotations.ReadOnlyHint, false)
destructive := ptr.Deref(tool.Annotations.DestructiveHint, false)
if readOnly && destructive {
t.Errorf("Tool %s is read-only and destructive, which is not allowed", tool.Name)
}
}
})
// McpToolProcessingSuite tests MCP tool processing (isToolApplicable)
type McpToolProcessingSuite struct {
BaseMcpSuite
}

func (s *McpToolProcessingSuite) TestUnrestricted() {
s.InitMcpClient()

tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
s.Require().NotNil(tools)

s.Run("ListTools returns tools", func() {
s.NoError(err, "call ListTools failed")
s.NotNilf(tools, "list tools failed")
})

s.Run("Destructive tools ARE NOT read only", func() {
for _, tool := range tools.Tools {
readOnly := ptr.Deref(tool.Annotations.ReadOnlyHint, false)
destructive := ptr.Deref(tool.Annotations.DestructiveHint, false)
s.Falsef(readOnly && destructive, "Tool %s is read-only and destructive, which is not allowed", tool.Name)
}
})
}

func TestReadOnly(t *testing.T) {
readOnlyServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{ReadOnly: true} }
testCaseWithContext(t, &mcpContext{before: readOnlyServer}, func(c *mcpContext) {
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
t.Run("ListTools returns tools", func(t *testing.T) {
if err != nil {
t.Fatalf("call ListTools failed %v", err)
}
})
t.Run("ListTools returns only read-only tools", func(t *testing.T) {
for _, tool := range tools.Tools {
if tool.Annotations.ReadOnlyHint == nil || !*tool.Annotations.ReadOnlyHint {
t.Errorf("Tool %s is not read-only but should be", tool.Name)
}
if tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint {
t.Errorf("Tool %s is destructive but should not be in read-only mode", tool.Name)
}
}
})
func (s *McpToolProcessingSuite) TestReadOnly() {
s.Require().NoError(toml.Unmarshal([]byte(`
read_only = true
`), s.Cfg), "Expected to parse read only server config")
s.InitMcpClient()

tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
s.Require().NotNil(tools)

s.Run("ListTools returns tools", func() {
s.NoError(err, "call ListTools failed")
s.NotNilf(tools, "list tools failed")
})

s.Run("ListTools returns only read-only tools", func() {
for _, tool := range tools.Tools {
s.Falsef(tool.Annotations.ReadOnlyHint == nil || !*tool.Annotations.ReadOnlyHint,
"Tool %s is not read-only but should be", tool.Name)
s.Falsef(tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint,
"Tool %s is destructive but should not be in read-only mode", tool.Name)
}
})
}

func TestDisableDestructive(t *testing.T) {
disableDestructiveServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{DisableDestructive: true} }
testCaseWithContext(t, &mcpContext{before: disableDestructiveServer}, func(c *mcpContext) {
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
t.Run("ListTools returns tools", func(t *testing.T) {
if err != nil {
t.Fatalf("call ListTools failed %v", err)
}
})
t.Run("ListTools does not return destructive tools", func(t *testing.T) {
for _, tool := range tools.Tools {
if tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint {
t.Errorf("Tool %s is destructive but should not be", tool.Name)
}
}
})
func (s *McpToolProcessingSuite) TestDisableDestructive() {
s.Require().NoError(toml.Unmarshal([]byte(`
disable_destructive = true
`), s.Cfg), "Expected to parse disable destructive server config")
s.InitMcpClient()

tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
s.Require().NotNil(tools)

s.Run("ListTools returns tools", func() {
s.NoError(err, "call ListTools failed")
s.NotNilf(tools, "list tools failed")
})

s.Run("ListTools does not return destructive tools", func() {
for _, tool := range tools.Tools {
s.Falsef(tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint,
"Tool %s is destructive but should not be in disable_destructive mode", tool.Name)
}
})
}

func TestEnabledTools(t *testing.T) {
enabledToolsServer := test.Must(config.ReadToml([]byte(`
func (s *McpToolProcessingSuite) TestEnabledTools() {
s.Require().NoError(toml.Unmarshal([]byte(`
enabled_tools = [ "namespaces_list", "events_list" ]
`)))
testCaseWithContext(t, &mcpContext{staticConfig: enabledToolsServer}, func(c *mcpContext) {
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
t.Run("ListTools returns tools", func(t *testing.T) {
if err != nil {
t.Fatalf("call ListTools failed %v", err)
}
})
t.Run("ListTools returns only explicitly enabled tools", func(t *testing.T) {
if len(tools.Tools) != 2 {
t.Fatalf("ListTools should return 2 tools, got %d", len(tools.Tools))
}
for _, tool := range tools.Tools {
if tool.Name != "namespaces_list" && tool.Name != "events_list" {
t.Errorf("Tool %s is not enabled but should be", tool.Name)
}
}
})
`), s.Cfg), "Expected to parse enabled tools server config")
s.InitMcpClient()

tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
s.Require().NotNil(tools)

s.Run("ListTools returns tools", func() {
s.NoError(err, "call ListTools failed")
s.NotNilf(tools, "list tools failed")
})
}

func TestDisabledTools(t *testing.T) {
testCaseWithContext(t, &mcpContext{
staticConfig: &config.StaticConfig{
DisabledTools: []string{"namespaces_list", "events_list"},
},
}, func(c *mcpContext) {
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
t.Run("ListTools returns tools", func(t *testing.T) {
if err != nil {
t.Fatalf("call ListTools failed %v", err)
}
})
t.Run("ListTools does not return disabled tools", func(t *testing.T) {
for _, tool := range tools.Tools {
if tool.Name == "namespaces_list" || tool.Name == "events_list" {
t.Errorf("Tool %s is not disabled but should be", tool.Name)
}
}
})
s.Run("ListTools returns only explicitly enabled tools", func() {
s.Len(tools.Tools, 2, "ListTools should return exactly 2 tools")
for _, tool := range tools.Tools {
s.Falsef(tool.Name != "namespaces_list" && tool.Name != "events_list",
"Tool %s is not enabled but should be", tool.Name)
}
})
}

func TestToolCallLogging(t *testing.T) {
testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) {
_, _ = c.callTool("configuration_view", map[string]interface{}{
"minified": false,
})
t.Run("Logs tool name", func(t *testing.T) {
expectedLog := "mcp tool call: configuration_view("
if !strings.Contains(c.logBuffer.String(), expectedLog) {
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
}
})
t.Run("Logs tool call arguments", func(t *testing.T) {
expected := `"mcp tool call: configuration_view\((.+)\)"`
m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String())
if len(m) != 2 {
t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String())
}
if m[1] != "map[minified:false]" {
t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1])
}
})
func (s *McpToolProcessingSuite) TestDisabledTools() {
s.Require().NoError(toml.Unmarshal([]byte(`
disabled_tools = [ "namespaces_list", "events_list" ]
`), s.Cfg), "Expected to parse disabled tools server config")
s.InitMcpClient()

tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
s.Require().NotNil(tools)

s.Run("ListTools returns tools", func() {
s.NoError(err, "call ListTools failed")
s.NotNilf(tools, "list tools failed")
})
before := func(c *mcpContext) {
c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{
"Accept-Encoding": "gzip",
"Authorization": "Bearer should-not-be-logged",
"authorization": "Bearer should-not-be-logged",
"a-loggable-header": "should-be-logged",
}))
}
testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) {
_, _ = c.callTool("configuration_view", map[string]interface{}{
"minified": false,
})
t.Run("Logs tool call headers", func(t *testing.T) {
expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged"
if !strings.Contains(c.logBuffer.String(), expectedLog) {
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
}
})
sensitiveHeaders := []string{
"Authorization:",
// TODO: Add more sensitive headers as needed

s.Run("ListTools does not return disabled tools", func() {
for _, tool := range tools.Tools {
s.Falsef(tool.Name == "namespaces_list" || tool.Name == "events_list",
"Tool %s is not disabled but should be", tool.Name)
}
t.Run("Does not log sensitive headers", func(t *testing.T) {
for _, header := range sensitiveHeaders {
if strings.Contains(c.logBuffer.String(), header) {
t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String())
}
}
})
t.Run("Does not log sensitive header values", func(t *testing.T) {
if strings.Contains(c.logBuffer.String(), "should-not-be-logged") {
t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String())
}
})
})
}

func TestMcpToolProcessing(t *testing.T) {
suite.Run(t, new(McpToolProcessingSuite))
}
Loading