@@ -7,6 +7,7 @@ package mcp
77import (
88 "context"
99 "encoding/json"
10+ "fmt"
1011 "strings"
1112 "testing"
1213)
@@ -88,3 +89,138 @@ func TestToolValidate(t *testing.T) {
8889 })
8990 }
9091}
92+
93+ // TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams.
94+ // This addresses a vulnerability where missing or null parameters could crash the server.
95+ func TestNilParamsHandling (t * testing.T ) {
96+ // Define test types for clarity
97+ type TestArgs struct {
98+ Name string `json:"name"`
99+ Value int `json:"value"`
100+ }
101+ type TestParams = * CallToolParamsFor [TestArgs ]
102+ type TestResult = * CallToolResultFor [string ]
103+
104+ // Simple test handler
105+ testHandler := func (ctx context.Context , ss * ServerSession , params TestParams ) (TestResult , error ) {
106+ result := "processed: " + params .Arguments .Name
107+ return & CallToolResultFor [string ]{StructuredContent : result }, nil
108+ }
109+
110+ methodInfo := newMethodInfo (testHandler , missingParamsOK )
111+
112+ // Helper function to test that unmarshalParams doesn't panic
113+ mustNotPanic := func (t * testing.T , rawMsg json.RawMessage ) Params {
114+ t .Helper ()
115+
116+ defer func () {
117+ if r := recover (); r != nil {
118+ t .Fatalf ("unmarshalParams panicked: %v" , r )
119+ }
120+ }()
121+
122+ params , err := methodInfo .unmarshalParams (rawMsg )
123+ if err != nil {
124+ t .Fatalf ("unmarshalParams failed: %v" , err )
125+ }
126+ if params == nil {
127+ t .Fatal ("unmarshalParams returned nil" )
128+ }
129+
130+ // Verify the result can be used safely
131+ typedParams := params .(TestParams )
132+ _ = typedParams .Name
133+ _ = typedParams .Arguments .Name
134+ _ = typedParams .Arguments .Value
135+
136+ return params
137+ }
138+
139+ // Test different nil parameter scenarios
140+ t .Run ("missing_params" , func (t * testing.T ) {
141+ mustNotPanic (t , nil )
142+ })
143+
144+ t .Run ("explicit_null" , func (t * testing.T ) {
145+ mustNotPanic (t , json .RawMessage (`null` ))
146+ })
147+
148+ t .Run ("empty_object" , func (t * testing.T ) {
149+ mustNotPanic (t , json .RawMessage (`{}` ))
150+ })
151+
152+ t .Run ("valid_params" , func (t * testing.T ) {
153+ rawMsg := json .RawMessage (`{"name":"test","arguments":{"name":"hello","value":42}}` )
154+ params := mustNotPanic (t , rawMsg )
155+
156+ // For valid params, also verify the values are parsed correctly
157+ typedParams := params .(TestParams )
158+ if typedParams .Name != "test" {
159+ t .Errorf ("Expected name 'test', got %q" , typedParams .Name )
160+ }
161+ if typedParams .Arguments .Name != "hello" {
162+ t .Errorf ("Expected argument name 'hello', got %q" , typedParams .Arguments .Name )
163+ }
164+ if typedParams .Arguments .Value != 42 {
165+ t .Errorf ("Expected argument value 42, got %d" , typedParams .Arguments .Value )
166+ }
167+ })
168+ }
169+
170+ // TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix
171+ func TestNilParamsEdgeCases (t * testing.T ) {
172+ type TestArgs struct {
173+ Name string `json:"name"`
174+ Value int `json:"value"`
175+ }
176+ type TestParams = * CallToolParamsFor [TestArgs ]
177+
178+ testHandler := func (ctx context.Context , ss * ServerSession , params TestParams ) (* CallToolResultFor [string ], error ) {
179+ return & CallToolResultFor [string ]{StructuredContent : "test" }, nil
180+ }
181+
182+ methodInfo := newMethodInfo (testHandler , missingParamsOK )
183+
184+ // These should fail normally, not be treated as nil params
185+ invalidCases := []json.RawMessage {
186+ json .RawMessage ("" ), // empty string - should error
187+ json .RawMessage ("[]" ), // array - should error
188+ json .RawMessage (`"null"` ), // string "null" - should error
189+ json .RawMessage ("0" ), // number - should error
190+ json .RawMessage ("false" ), // boolean - should error
191+ }
192+
193+ for i , rawMsg := range invalidCases {
194+ t .Run (fmt .Sprintf ("invalid_case_%d" , i ), func (t * testing.T ) {
195+ params , err := methodInfo .unmarshalParams (rawMsg )
196+ if err == nil && params == nil {
197+ t .Error ("Should not return nil params without error" )
198+ }
199+ })
200+ }
201+
202+ // Test that methods without missingParamsOK flag properly reject nil params
203+ t .Run ("reject_when_params_required" , func (t * testing.T ) {
204+ methodInfoStrict := newMethodInfo (testHandler , 0 ) // No missingParamsOK flag
205+
206+ testCases := []struct {
207+ name string
208+ params json.RawMessage
209+ }{
210+ {"nil_params" , nil },
211+ {"null_params" , json .RawMessage (`null` )},
212+ }
213+
214+ for _ , tc := range testCases {
215+ t .Run (tc .name , func (t * testing.T ) {
216+ _ , err := methodInfoStrict .unmarshalParams (tc .params )
217+ if err == nil {
218+ t .Error ("Expected error for required params, got nil" )
219+ }
220+ if ! strings .Contains (err .Error (), "missing required \" params\" " ) {
221+ t .Errorf ("Expected 'missing required params' error, got: %v" , err )
222+ }
223+ })
224+ }
225+ })
226+ }
0 commit comments