Skip to content

Commit be62a09

Browse files
committed
Add comprehensive test coverage for nil params scenarios
Complements the methodFlags system from modelcontextprotocol#210 with additional unit tests: - Tests nil RawMessage and explicit JSON null handling in unmarshalParams - Tests edge cases with different JSON types (empty string, array, number, boolean) - Validates proper error handling for required vs optional params with methodFlags - Provides focused unit test coverage alongside existing conformance tests The tests verify that the panic vulnerability from modelcontextprotocol#186 is properly handled by the upstream methodFlags implementation.
1 parent b34ba21 commit be62a09

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed

mcp/shared_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package mcp
77
import (
88
"context"
99
"encoding/json"
10+
"fmt"
1011
"strings"
1112
"testing"
1213
)
@@ -88,3 +89,146 @@ func TestToolValidate(t *testing.T) {
8889
})
8990
}
9091
}
92+
93+
// TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams.
94+
// This addresses a vulnerability where missing or null parameters could crash the server.
95+
func TestNilParamsHandling(t *testing.T) {
96+
// Define test types for clarity
97+
type TestArgs struct {
98+
Name string `json:"name"`
99+
Value int `json:"value"`
100+
}
101+
type TestParams = *CallToolParamsFor[TestArgs]
102+
type TestResult = *CallToolResultFor[string]
103+
104+
// Simple test handler
105+
testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (TestResult, error) {
106+
result := "processed: " + params.Arguments.Name
107+
return &CallToolResultFor[string]{StructuredContent: result}, nil
108+
}
109+
110+
methodInfo := newMethodInfo(testHandler, missingParamsOK)
111+
112+
// Helper function to test that unmarshalParams doesn't panic and handles nil gracefully
113+
mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params {
114+
t.Helper()
115+
116+
defer func() {
117+
if r := recover(); r != nil {
118+
t.Fatalf("unmarshalParams panicked: %v", r)
119+
}
120+
}()
121+
122+
params, err := methodInfo.unmarshalParams(rawMsg)
123+
if err != nil {
124+
t.Fatalf("unmarshalParams failed: %v", err)
125+
}
126+
127+
if expectNil {
128+
if params != nil {
129+
t.Fatalf("Expected nil params, got %v", params)
130+
}
131+
return params
132+
}
133+
134+
if params == nil {
135+
t.Fatal("unmarshalParams returned unexpected nil")
136+
}
137+
138+
// Verify the result can be used safely
139+
typedParams := params.(TestParams)
140+
_ = typedParams.Name
141+
_ = typedParams.Arguments.Name
142+
_ = typedParams.Arguments.Value
143+
144+
return params
145+
}
146+
147+
// Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil
148+
t.Run("missing_params", func(t *testing.T) {
149+
mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag
150+
})
151+
152+
t.Run("explicit_null", func(t *testing.T) {
153+
mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag
154+
})
155+
156+
t.Run("empty_object", func(t *testing.T) {
157+
mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params
158+
})
159+
160+
t.Run("valid_params", func(t *testing.T) {
161+
rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`)
162+
params := mustNotPanic(t, rawMsg, false)
163+
164+
// For valid params, also verify the values are parsed correctly
165+
typedParams := params.(TestParams)
166+
if typedParams.Name != "test" {
167+
t.Errorf("Expected name 'test', got %q", typedParams.Name)
168+
}
169+
if typedParams.Arguments.Name != "hello" {
170+
t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name)
171+
}
172+
if typedParams.Arguments.Value != 42 {
173+
t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value)
174+
}
175+
})
176+
}
177+
178+
// TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix
179+
func TestNilParamsEdgeCases(t *testing.T) {
180+
type TestArgs struct {
181+
Name string `json:"name"`
182+
Value int `json:"value"`
183+
}
184+
type TestParams = *CallToolParamsFor[TestArgs]
185+
186+
testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (*CallToolResultFor[string], error) {
187+
return &CallToolResultFor[string]{StructuredContent: "test"}, nil
188+
}
189+
190+
methodInfo := newMethodInfo(testHandler, missingParamsOK)
191+
192+
// These should fail normally, not be treated as nil params
193+
invalidCases := []json.RawMessage{
194+
json.RawMessage(""), // empty string - should error
195+
json.RawMessage("[]"), // array - should error
196+
json.RawMessage(`"null"`), // string "null" - should error
197+
json.RawMessage("0"), // number - should error
198+
json.RawMessage("false"), // boolean - should error
199+
}
200+
201+
for i, rawMsg := range invalidCases {
202+
t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) {
203+
params, err := methodInfo.unmarshalParams(rawMsg)
204+
if err == nil && params == nil {
205+
t.Error("Should not return nil params without error")
206+
}
207+
})
208+
}
209+
210+
// Test that methods without missingParamsOK flag properly reject nil params
211+
t.Run("reject_when_params_required", func(t *testing.T) {
212+
methodInfoStrict := newMethodInfo(testHandler, 0) // No missingParamsOK flag
213+
214+
testCases := []struct {
215+
name string
216+
params json.RawMessage
217+
}{
218+
{"nil_params", nil},
219+
{"null_params", json.RawMessage(`null`)},
220+
}
221+
222+
for _, tc := range testCases {
223+
t.Run(tc.name, func(t *testing.T) {
224+
_, err := methodInfoStrict.unmarshalParams(tc.params)
225+
if err == nil {
226+
t.Error("Expected error for required params, got nil")
227+
}
228+
if !strings.Contains(err.Error(), "missing required \"params\"") {
229+
t.Errorf("Expected 'missing required params' error, got: %v", err)
230+
}
231+
})
232+
}
233+
})
234+
}

0 commit comments

Comments
 (0)