Skip to content

Commit 227e0c7

Browse files
committed
mcp: fix nil params crash vulnerability (#186)
Handle nil RawMessage and explicit JSON "null" in unmarshalParams to prevent server crashes. When JSON-RPC requests are sent without parameters or with null parameters, ensure orZero receives a valid struct instead of nil. - Add nullJSON constant for efficient null comparison - Handle nil RawMessage by creating empty parameter struct - Handle explicit "null" JSON using bytes.Equal - Add comprehensive test coverage for nil parameter scenarios
1 parent b34ba21 commit 227e0c7

File tree

2 files changed

+163
-10
lines changed

2 files changed

+163
-10
lines changed

mcp/shared.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package mcp
1111

1212
import (
13+
"bytes"
1314
"context"
1415
"encoding/json"
1516
"fmt"
@@ -33,6 +34,9 @@ var supportedProtocolVersions = []string{
3334
"2024-11-05",
3435
}
3536

37+
// Precomputed null JSON for efficient comparison (issue #186 fix)
38+
var nullJSON = []byte("null")
39+
3640
// A MethodHandler handles MCP messages.
3741
// For methods, exactly one of the return values must be nil.
3842
// For notifications, both must be nil.
@@ -217,21 +221,34 @@ func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHand
217221
flags: flags,
218222
unmarshalParams: func(m json.RawMessage) (Params, error) {
219223
var p P
220-
if m != nil {
221-
if err := json.Unmarshal(m, &p); err != nil {
222-
return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err)
224+
225+
// Handle nil RawMessage (params missing entirely)
226+
if m == nil {
227+
if flags&missingParamsOK == 0 {
228+
return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest)
229+
}
230+
return orZero[Params](new(T)), nil
231+
}
232+
233+
// Handle explicit JSON "null" (params set to null)
234+
if bytes.Equal(m, nullJSON) {
235+
if flags&missingParamsOK == 0 {
236+
return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest)
223237
}
238+
return orZero[Params](new(T)), nil
224239
}
225-
// We must check missingParamsOK here, in addition to checkRequest, to
226-
// catch the edge cases where "params" is set to JSON null.
227-
// See also https://go.dev/issue/33835.
228-
//
229-
// We need to ensure that p is non-null to guard against crashes, as our
230-
// internal code or externally provided handlers may assume that params
231-
// is non-null.
240+
241+
// Normal JSON unmarshaling
242+
if err := json.Unmarshal(m, &p); err != nil {
243+
return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err)
244+
}
245+
246+
// Final check after unmarshaling - this guards against crashes when p is nil
247+
// due to Go's JSON unmarshaling behavior with interface types
232248
if flags&missingParamsOK == 0 && p == nil {
233249
return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest)
234250
}
251+
235252
return orZero[Params](p), nil
236253
},
237254
handleMethod: MethodHandler[S](func(ctx context.Context, session S, _ string, params Params) (Result, error) {

mcp/shared_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package mcp
77
import (
88
"context"
99
"encoding/json"
10+
"fmt"
1011
"strings"
1112
"testing"
1213
)
@@ -88,3 +89,138 @@ func TestToolValidate(t *testing.T) {
8889
})
8990
}
9091
}
92+
93+
// TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams.
94+
// This addresses a vulnerability where missing or null parameters could crash the server.
95+
func TestNilParamsHandling(t *testing.T) {
96+
// Define test types for clarity
97+
type TestArgs struct {
98+
Name string `json:"name"`
99+
Value int `json:"value"`
100+
}
101+
type TestParams = *CallToolParamsFor[TestArgs]
102+
type TestResult = *CallToolResultFor[string]
103+
104+
// Simple test handler
105+
testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (TestResult, error) {
106+
result := "processed: " + params.Arguments.Name
107+
return &CallToolResultFor[string]{StructuredContent: result}, nil
108+
}
109+
110+
methodInfo := newMethodInfo(testHandler, missingParamsOK)
111+
112+
// Helper function to test that unmarshalParams doesn't panic
113+
mustNotPanic := func(t *testing.T, rawMsg json.RawMessage) Params {
114+
t.Helper()
115+
116+
defer func() {
117+
if r := recover(); r != nil {
118+
t.Fatalf("unmarshalParams panicked: %v", r)
119+
}
120+
}()
121+
122+
params, err := methodInfo.unmarshalParams(rawMsg)
123+
if err != nil {
124+
t.Fatalf("unmarshalParams failed: %v", err)
125+
}
126+
if params == nil {
127+
t.Fatal("unmarshalParams returned nil")
128+
}
129+
130+
// Verify the result can be used safely
131+
typedParams := params.(TestParams)
132+
_ = typedParams.Name
133+
_ = typedParams.Arguments.Name
134+
_ = typedParams.Arguments.Value
135+
136+
return params
137+
}
138+
139+
// Test different nil parameter scenarios
140+
t.Run("missing_params", func(t *testing.T) {
141+
mustNotPanic(t, nil)
142+
})
143+
144+
t.Run("explicit_null", func(t *testing.T) {
145+
mustNotPanic(t, json.RawMessage(`null`))
146+
})
147+
148+
t.Run("empty_object", func(t *testing.T) {
149+
mustNotPanic(t, json.RawMessage(`{}`))
150+
})
151+
152+
t.Run("valid_params", func(t *testing.T) {
153+
rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`)
154+
params := mustNotPanic(t, rawMsg)
155+
156+
// For valid params, also verify the values are parsed correctly
157+
typedParams := params.(TestParams)
158+
if typedParams.Name != "test" {
159+
t.Errorf("Expected name 'test', got %q", typedParams.Name)
160+
}
161+
if typedParams.Arguments.Name != "hello" {
162+
t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name)
163+
}
164+
if typedParams.Arguments.Value != 42 {
165+
t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value)
166+
}
167+
})
168+
}
169+
170+
// TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix
171+
func TestNilParamsEdgeCases(t *testing.T) {
172+
type TestArgs struct {
173+
Name string `json:"name"`
174+
Value int `json:"value"`
175+
}
176+
type TestParams = *CallToolParamsFor[TestArgs]
177+
178+
testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (*CallToolResultFor[string], error) {
179+
return &CallToolResultFor[string]{StructuredContent: "test"}, nil
180+
}
181+
182+
methodInfo := newMethodInfo(testHandler, missingParamsOK)
183+
184+
// These should fail normally, not be treated as nil params
185+
invalidCases := []json.RawMessage{
186+
json.RawMessage(""), // empty string - should error
187+
json.RawMessage("[]"), // array - should error
188+
json.RawMessage(`"null"`), // string "null" - should error
189+
json.RawMessage("0"), // number - should error
190+
json.RawMessage("false"), // boolean - should error
191+
}
192+
193+
for i, rawMsg := range invalidCases {
194+
t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) {
195+
params, err := methodInfo.unmarshalParams(rawMsg)
196+
if err == nil && params == nil {
197+
t.Error("Should not return nil params without error")
198+
}
199+
})
200+
}
201+
202+
// Test that methods without missingParamsOK flag properly reject nil params
203+
t.Run("reject_when_params_required", func(t *testing.T) {
204+
methodInfoStrict := newMethodInfo(testHandler, 0) // No missingParamsOK flag
205+
206+
testCases := []struct {
207+
name string
208+
params json.RawMessage
209+
}{
210+
{"nil_params", nil},
211+
{"null_params", json.RawMessage(`null`)},
212+
}
213+
214+
for _, tc := range testCases {
215+
t.Run(tc.name, func(t *testing.T) {
216+
_, err := methodInfoStrict.unmarshalParams(tc.params)
217+
if err == nil {
218+
t.Error("Expected error for required params, got nil")
219+
}
220+
if !strings.Contains(err.Error(), "missing required \"params\"") {
221+
t.Errorf("Expected 'missing required params' error, got: %v", err)
222+
}
223+
})
224+
}
225+
})
226+
}

0 commit comments

Comments
 (0)