@@ -7,6 +7,7 @@ package mcp
77import (
88 "context"
99 "encoding/json"
10+ "fmt"
1011 "strings"
1112 "testing"
1213)
@@ -88,3 +89,146 @@ func TestToolValidate(t *testing.T) {
8889 })
8990 }
9091}
92+
93+ // TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams.
94+ // This addresses a vulnerability where missing or null parameters could crash the server.
95+ func TestNilParamsHandling (t * testing.T ) {
96+ // Define test types for clarity
97+ type TestArgs struct {
98+ Name string `json:"name"`
99+ Value int `json:"value"`
100+ }
101+ type TestParams = * CallToolParamsFor [TestArgs ]
102+ type TestResult = * CallToolResultFor [string ]
103+
104+ // Simple test handler
105+ testHandler := func (ctx context.Context , ss * ServerSession , params TestParams ) (TestResult , error ) {
106+ result := "processed: " + params .Arguments .Name
107+ return & CallToolResultFor [string ]{StructuredContent : result }, nil
108+ }
109+
110+ methodInfo := newMethodInfo (testHandler , missingParamsOK )
111+
112+ // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully
113+ mustNotPanic := func (t * testing.T , rawMsg json.RawMessage , expectNil bool ) Params {
114+ t .Helper ()
115+
116+ defer func () {
117+ if r := recover (); r != nil {
118+ t .Fatalf ("unmarshalParams panicked: %v" , r )
119+ }
120+ }()
121+
122+ params , err := methodInfo .unmarshalParams (rawMsg )
123+ if err != nil {
124+ t .Fatalf ("unmarshalParams failed: %v" , err )
125+ }
126+
127+ if expectNil {
128+ if params != nil {
129+ t .Fatalf ("Expected nil params, got %v" , params )
130+ }
131+ return params
132+ }
133+
134+ if params == nil {
135+ t .Fatal ("unmarshalParams returned unexpected nil" )
136+ }
137+
138+ // Verify the result can be used safely
139+ typedParams := params .(TestParams )
140+ _ = typedParams .Name
141+ _ = typedParams .Arguments .Name
142+ _ = typedParams .Arguments .Value
143+
144+ return params
145+ }
146+
147+ // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil
148+ t .Run ("missing_params" , func (t * testing.T ) {
149+ mustNotPanic (t , nil , true ) // Expect nil with missingParamsOK flag
150+ })
151+
152+ t .Run ("explicit_null" , func (t * testing.T ) {
153+ mustNotPanic (t , json .RawMessage (`null` ), true ) // Expect nil with missingParamsOK flag
154+ })
155+
156+ t .Run ("empty_object" , func (t * testing.T ) {
157+ mustNotPanic (t , json .RawMessage (`{}` ), false ) // Empty object should create valid params
158+ })
159+
160+ t .Run ("valid_params" , func (t * testing.T ) {
161+ rawMsg := json .RawMessage (`{"name":"test","arguments":{"name":"hello","value":42}}` )
162+ params := mustNotPanic (t , rawMsg , false )
163+
164+ // For valid params, also verify the values are parsed correctly
165+ typedParams := params .(TestParams )
166+ if typedParams .Name != "test" {
167+ t .Errorf ("Expected name 'test', got %q" , typedParams .Name )
168+ }
169+ if typedParams .Arguments .Name != "hello" {
170+ t .Errorf ("Expected argument name 'hello', got %q" , typedParams .Arguments .Name )
171+ }
172+ if typedParams .Arguments .Value != 42 {
173+ t .Errorf ("Expected argument value 42, got %d" , typedParams .Arguments .Value )
174+ }
175+ })
176+ }
177+
178+ // TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix
179+ func TestNilParamsEdgeCases (t * testing.T ) {
180+ type TestArgs struct {
181+ Name string `json:"name"`
182+ Value int `json:"value"`
183+ }
184+ type TestParams = * CallToolParamsFor [TestArgs ]
185+
186+ testHandler := func (ctx context.Context , ss * ServerSession , params TestParams ) (* CallToolResultFor [string ], error ) {
187+ return & CallToolResultFor [string ]{StructuredContent : "test" }, nil
188+ }
189+
190+ methodInfo := newMethodInfo (testHandler , missingParamsOK )
191+
192+ // These should fail normally, not be treated as nil params
193+ invalidCases := []json.RawMessage {
194+ json .RawMessage ("" ), // empty string - should error
195+ json .RawMessage ("[]" ), // array - should error
196+ json .RawMessage (`"null"` ), // string "null" - should error
197+ json .RawMessage ("0" ), // number - should error
198+ json .RawMessage ("false" ), // boolean - should error
199+ }
200+
201+ for i , rawMsg := range invalidCases {
202+ t .Run (fmt .Sprintf ("invalid_case_%d" , i ), func (t * testing.T ) {
203+ params , err := methodInfo .unmarshalParams (rawMsg )
204+ if err == nil && params == nil {
205+ t .Error ("Should not return nil params without error" )
206+ }
207+ })
208+ }
209+
210+ // Test that methods without missingParamsOK flag properly reject nil params
211+ t .Run ("reject_when_params_required" , func (t * testing.T ) {
212+ methodInfoStrict := newMethodInfo (testHandler , 0 ) // No missingParamsOK flag
213+
214+ testCases := []struct {
215+ name string
216+ params json.RawMessage
217+ }{
218+ {"nil_params" , nil },
219+ {"null_params" , json .RawMessage (`null` )},
220+ }
221+
222+ for _ , tc := range testCases {
223+ t .Run (tc .name , func (t * testing.T ) {
224+ _ , err := methodInfoStrict .unmarshalParams (tc .params )
225+ if err == nil {
226+ t .Error ("Expected error for required params, got nil" )
227+ }
228+ if ! strings .Contains (err .Error (), "missing required \" params\" " ) {
229+ t .Errorf ("Expected 'missing required params' error, got: %v" , err )
230+ }
231+ })
232+ }
233+ })
234+ }
0 commit comments