Skip to content

Commit 2fd8098

Browse files
committed
test(mcp):update mcp processing tests to use testify and improve readability
Signed-off-by: Marc Nuri <[email protected]>
1 parent 07783a4 commit 2fd8098

File tree

2 files changed

+170
-152
lines changed

2 files changed

+170
-152
lines changed

pkg/mcp/mcp_middleware_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package mcp
2+
3+
import (
4+
"regexp"
5+
"strings"
6+
"testing"
7+
8+
"github.com/mark3labs/mcp-go/client/transport"
9+
)
10+
11+
func TestToolCallLogging(t *testing.T) {
12+
testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) {
13+
_, _ = c.callTool("configuration_view", map[string]interface{}{
14+
"minified": false,
15+
})
16+
t.Run("Logs tool name", func(t *testing.T) {
17+
expectedLog := "mcp tool call: configuration_view("
18+
if !strings.Contains(c.logBuffer.String(), expectedLog) {
19+
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
20+
}
21+
})
22+
t.Run("Logs tool call arguments", func(t *testing.T) {
23+
expected := `"mcp tool call: configuration_view\((.+)\)"`
24+
m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String())
25+
if len(m) != 2 {
26+
t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String())
27+
}
28+
if m[1] != "map[minified:false]" {
29+
t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1])
30+
}
31+
})
32+
})
33+
before := func(c *mcpContext) {
34+
c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{
35+
"Accept-Encoding": "gzip",
36+
"Authorization": "Bearer should-not-be-logged",
37+
"authorization": "Bearer should-not-be-logged",
38+
"a-loggable-header": "should-be-logged",
39+
}))
40+
}
41+
testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) {
42+
_, _ = c.callTool("configuration_view", map[string]interface{}{
43+
"minified": false,
44+
})
45+
t.Run("Logs tool call headers", func(t *testing.T) {
46+
expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged"
47+
if !strings.Contains(c.logBuffer.String(), expectedLog) {
48+
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
49+
}
50+
})
51+
sensitiveHeaders := []string{
52+
"Authorization:",
53+
// TODO: Add more sensitive headers as needed
54+
}
55+
t.Run("Does not log sensitive headers", func(t *testing.T) {
56+
for _, header := range sensitiveHeaders {
57+
if strings.Contains(c.logBuffer.String(), header) {
58+
t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String())
59+
}
60+
}
61+
})
62+
t.Run("Does not log sensitive header values", func(t *testing.T) {
63+
if strings.Contains(c.logBuffer.String(), "should-not-be-logged") {
64+
t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String())
65+
}
66+
})
67+
})
68+
}

pkg/mcp/mcp_tools_test.go

Lines changed: 102 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -1,180 +1,130 @@
11
package mcp
22

33
import (
4-
"regexp"
5-
"strings"
64
"testing"
75

8-
"github.com/mark3labs/mcp-go/client/transport"
6+
"github.com/BurntSushi/toml"
97
"github.com/mark3labs/mcp-go/mcp"
8+
"github.com/stretchr/testify/suite"
109
"k8s.io/utils/ptr"
11-
12-
"github.com/containers/kubernetes-mcp-server/internal/test"
13-
"github.com/containers/kubernetes-mcp-server/pkg/config"
1410
)
1511

16-
func TestUnrestricted(t *testing.T) {
17-
testCase(t, func(c *mcpContext) {
18-
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
19-
t.Run("ListTools returns tools", func(t *testing.T) {
20-
if err != nil {
21-
t.Fatalf("call ListTools failed %v", err)
22-
}
23-
})
24-
t.Run("Destructive tools ARE NOT read only", func(t *testing.T) {
25-
for _, tool := range tools.Tools {
26-
readOnly := ptr.Deref(tool.Annotations.ReadOnlyHint, false)
27-
destructive := ptr.Deref(tool.Annotations.DestructiveHint, false)
28-
if readOnly && destructive {
29-
t.Errorf("Tool %s is read-only and destructive, which is not allowed", tool.Name)
30-
}
31-
}
32-
})
12+
// McpToolProcessingSuite tests MCP tool processing (isToolApplicable)
13+
type McpToolProcessingSuite struct {
14+
BaseMcpSuite
15+
}
16+
17+
func (s *McpToolProcessingSuite) TestUnrestricted() {
18+
s.InitMcpClient()
19+
20+
tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
21+
s.Require().NotNil(tools)
22+
23+
s.Run("ListTools returns tools", func() {
24+
s.NoError(err, "call ListTools failed")
25+
s.NotNilf(tools, "list tools failed")
26+
})
27+
28+
s.Run("Destructive tools ARE NOT read only", func() {
29+
for _, tool := range tools.Tools {
30+
readOnly := ptr.Deref(tool.Annotations.ReadOnlyHint, false)
31+
destructive := ptr.Deref(tool.Annotations.DestructiveHint, false)
32+
s.Falsef(readOnly && destructive, "Tool %s is read-only and destructive, which is not allowed", tool.Name)
33+
}
3334
})
3435
}
3536

36-
func TestReadOnly(t *testing.T) {
37-
readOnlyServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{ReadOnly: true} }
38-
testCaseWithContext(t, &mcpContext{before: readOnlyServer}, func(c *mcpContext) {
39-
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
40-
t.Run("ListTools returns tools", func(t *testing.T) {
41-
if err != nil {
42-
t.Fatalf("call ListTools failed %v", err)
43-
}
44-
})
45-
t.Run("ListTools returns only read-only tools", func(t *testing.T) {
46-
for _, tool := range tools.Tools {
47-
if tool.Annotations.ReadOnlyHint == nil || !*tool.Annotations.ReadOnlyHint {
48-
t.Errorf("Tool %s is not read-only but should be", tool.Name)
49-
}
50-
if tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint {
51-
t.Errorf("Tool %s is destructive but should not be in read-only mode", tool.Name)
52-
}
53-
}
54-
})
37+
func (s *McpToolProcessingSuite) TestReadOnly() {
38+
s.Require().NoError(toml.Unmarshal([]byte(`
39+
read_only = true
40+
`), s.Cfg), "Expected to parse read only server config")
41+
s.InitMcpClient()
42+
43+
tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
44+
s.Require().NotNil(tools)
45+
46+
s.Run("ListTools returns tools", func() {
47+
s.NoError(err, "call ListTools failed")
48+
s.NotNilf(tools, "list tools failed")
49+
})
50+
51+
s.Run("ListTools returns only read-only tools", func() {
52+
for _, tool := range tools.Tools {
53+
s.Falsef(tool.Annotations.ReadOnlyHint == nil || !*tool.Annotations.ReadOnlyHint,
54+
"Tool %s is not read-only but should be", tool.Name)
55+
s.Falsef(tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint,
56+
"Tool %s is destructive but should not be in read-only mode", tool.Name)
57+
}
5558
})
5659
}
5760

58-
func TestDisableDestructive(t *testing.T) {
59-
disableDestructiveServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{DisableDestructive: true} }
60-
testCaseWithContext(t, &mcpContext{before: disableDestructiveServer}, func(c *mcpContext) {
61-
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
62-
t.Run("ListTools returns tools", func(t *testing.T) {
63-
if err != nil {
64-
t.Fatalf("call ListTools failed %v", err)
65-
}
66-
})
67-
t.Run("ListTools does not return destructive tools", func(t *testing.T) {
68-
for _, tool := range tools.Tools {
69-
if tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint {
70-
t.Errorf("Tool %s is destructive but should not be", tool.Name)
71-
}
72-
}
73-
})
61+
func (s *McpToolProcessingSuite) TestDisableDestructive() {
62+
s.Require().NoError(toml.Unmarshal([]byte(`
63+
disable_destructive = true
64+
`), s.Cfg), "Expected to parse disable destructive server config")
65+
s.InitMcpClient()
66+
67+
tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
68+
s.Require().NotNil(tools)
69+
70+
s.Run("ListTools returns tools", func() {
71+
s.NoError(err, "call ListTools failed")
72+
s.NotNilf(tools, "list tools failed")
73+
})
74+
75+
s.Run("ListTools does not return destructive tools", func() {
76+
for _, tool := range tools.Tools {
77+
s.Falsef(tool.Annotations.DestructiveHint != nil && *tool.Annotations.DestructiveHint,
78+
"Tool %s is destructive but should not be in disable_destructive mode", tool.Name)
79+
}
7480
})
7581
}
7682

77-
func TestEnabledTools(t *testing.T) {
78-
enabledToolsServer := test.Must(config.ReadToml([]byte(`
83+
func (s *McpToolProcessingSuite) TestEnabledTools() {
84+
s.Require().NoError(toml.Unmarshal([]byte(`
7985
enabled_tools = [ "namespaces_list", "events_list" ]
80-
`)))
81-
testCaseWithContext(t, &mcpContext{staticConfig: enabledToolsServer}, func(c *mcpContext) {
82-
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
83-
t.Run("ListTools returns tools", func(t *testing.T) {
84-
if err != nil {
85-
t.Fatalf("call ListTools failed %v", err)
86-
}
87-
})
88-
t.Run("ListTools returns only explicitly enabled tools", func(t *testing.T) {
89-
if len(tools.Tools) != 2 {
90-
t.Fatalf("ListTools should return 2 tools, got %d", len(tools.Tools))
91-
}
92-
for _, tool := range tools.Tools {
93-
if tool.Name != "namespaces_list" && tool.Name != "events_list" {
94-
t.Errorf("Tool %s is not enabled but should be", tool.Name)
95-
}
96-
}
97-
})
86+
`), s.Cfg), "Expected to parse enabled tools server config")
87+
s.InitMcpClient()
88+
89+
tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
90+
s.Require().NotNil(tools)
91+
92+
s.Run("ListTools returns tools", func() {
93+
s.NoError(err, "call ListTools failed")
94+
s.NotNilf(tools, "list tools failed")
9895
})
99-
}
10096

101-
func TestDisabledTools(t *testing.T) {
102-
testCaseWithContext(t, &mcpContext{
103-
staticConfig: &config.StaticConfig{
104-
DisabledTools: []string{"namespaces_list", "events_list"},
105-
},
106-
}, func(c *mcpContext) {
107-
tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{})
108-
t.Run("ListTools returns tools", func(t *testing.T) {
109-
if err != nil {
110-
t.Fatalf("call ListTools failed %v", err)
111-
}
112-
})
113-
t.Run("ListTools does not return disabled tools", func(t *testing.T) {
114-
for _, tool := range tools.Tools {
115-
if tool.Name == "namespaces_list" || tool.Name == "events_list" {
116-
t.Errorf("Tool %s is not disabled but should be", tool.Name)
117-
}
118-
}
119-
})
97+
s.Run("ListTools returns only explicitly enabled tools", func() {
98+
s.Len(tools.Tools, 2, "ListTools should return exactly 2 tools")
99+
for _, tool := range tools.Tools {
100+
s.Falsef(tool.Name != "namespaces_list" && tool.Name != "events_list",
101+
"Tool %s is not enabled but should be", tool.Name)
102+
}
120103
})
121104
}
122105

123-
func TestToolCallLogging(t *testing.T) {
124-
testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) {
125-
_, _ = c.callTool("configuration_view", map[string]interface{}{
126-
"minified": false,
127-
})
128-
t.Run("Logs tool name", func(t *testing.T) {
129-
expectedLog := "mcp tool call: configuration_view("
130-
if !strings.Contains(c.logBuffer.String(), expectedLog) {
131-
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
132-
}
133-
})
134-
t.Run("Logs tool call arguments", func(t *testing.T) {
135-
expected := `"mcp tool call: configuration_view\((.+)\)"`
136-
m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String())
137-
if len(m) != 2 {
138-
t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String())
139-
}
140-
if m[1] != "map[minified:false]" {
141-
t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1])
142-
}
143-
})
106+
func (s *McpToolProcessingSuite) TestDisabledTools() {
107+
s.Require().NoError(toml.Unmarshal([]byte(`
108+
disabled_tools = [ "namespaces_list", "events_list" ]
109+
`), s.Cfg), "Expected to parse disabled tools server config")
110+
s.InitMcpClient()
111+
112+
tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{})
113+
s.Require().NotNil(tools)
114+
115+
s.Run("ListTools returns tools", func() {
116+
s.NoError(err, "call ListTools failed")
117+
s.NotNilf(tools, "list tools failed")
144118
})
145-
before := func(c *mcpContext) {
146-
c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{
147-
"Accept-Encoding": "gzip",
148-
"Authorization": "Bearer should-not-be-logged",
149-
"authorization": "Bearer should-not-be-logged",
150-
"a-loggable-header": "should-be-logged",
151-
}))
152-
}
153-
testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) {
154-
_, _ = c.callTool("configuration_view", map[string]interface{}{
155-
"minified": false,
156-
})
157-
t.Run("Logs tool call headers", func(t *testing.T) {
158-
expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged"
159-
if !strings.Contains(c.logBuffer.String(), expectedLog) {
160-
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
161-
}
162-
})
163-
sensitiveHeaders := []string{
164-
"Authorization:",
165-
// TODO: Add more sensitive headers as needed
119+
120+
s.Run("ListTools does not return disabled tools", func() {
121+
for _, tool := range tools.Tools {
122+
s.Falsef(tool.Name == "namespaces_list" || tool.Name == "events_list",
123+
"Tool %s is not disabled but should be", tool.Name)
166124
}
167-
t.Run("Does not log sensitive headers", func(t *testing.T) {
168-
for _, header := range sensitiveHeaders {
169-
if strings.Contains(c.logBuffer.String(), header) {
170-
t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String())
171-
}
172-
}
173-
})
174-
t.Run("Does not log sensitive header values", func(t *testing.T) {
175-
if strings.Contains(c.logBuffer.String(), "should-not-be-logged") {
176-
t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String())
177-
}
178-
})
179125
})
180126
}
127+
128+
func TestMcpToolProcessing(t *testing.T) {
129+
suite.Run(t, new(McpToolProcessingSuite))
130+
}

0 commit comments

Comments
 (0)