Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions mcp/shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package mcp
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"
)
Expand Down Expand Up @@ -88,3 +89,146 @@ func TestToolValidate(t *testing.T) {
})
}
}

// TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams.
// This addresses a vulnerability where missing or null parameters could crash the server.
func TestNilParamsHandling(t *testing.T) {
// Define test types for clarity
type TestArgs struct {
Name string `json:"name"`
Value int `json:"value"`
}
type TestParams = *CallToolParamsFor[TestArgs]
type TestResult = *CallToolResultFor[string]

// Simple test handler
testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (TestResult, error) {
result := "processed: " + params.Arguments.Name
return &CallToolResultFor[string]{StructuredContent: result}, nil
}

methodInfo := newMethodInfo(testHandler, missingParamsOK)

// Helper function to test that unmarshalParams doesn't panic and handles nil gracefully
mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params {
t.Helper()

defer func() {
if r := recover(); r != nil {
t.Fatalf("unmarshalParams panicked: %v", r)
}
}()

params, err := methodInfo.unmarshalParams(rawMsg)
if err != nil {
t.Fatalf("unmarshalParams failed: %v", err)
}

if expectNil {
if params != nil {
t.Fatalf("Expected nil params, got %v", params)
}
return params
}

if params == nil {
t.Fatal("unmarshalParams returned unexpected nil")
}

// Verify the result can be used safely
typedParams := params.(TestParams)
_ = typedParams.Name
_ = typedParams.Arguments.Name
_ = typedParams.Arguments.Value

return params
}

// Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil
t.Run("missing_params", func(t *testing.T) {
mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag
})

t.Run("explicit_null", func(t *testing.T) {
mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag
})

t.Run("empty_object", func(t *testing.T) {
mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params
})

t.Run("valid_params", func(t *testing.T) {
rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`)
params := mustNotPanic(t, rawMsg, false)

// For valid params, also verify the values are parsed correctly
typedParams := params.(TestParams)
if typedParams.Name != "test" {
t.Errorf("Expected name 'test', got %q", typedParams.Name)
}
if typedParams.Arguments.Name != "hello" {
t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name)
}
if typedParams.Arguments.Value != 42 {
t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value)
}
})
}

// TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix
func TestNilParamsEdgeCases(t *testing.T) {
type TestArgs struct {
Name string `json:"name"`
Value int `json:"value"`
}
type TestParams = *CallToolParamsFor[TestArgs]

testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (*CallToolResultFor[string], error) {
return &CallToolResultFor[string]{StructuredContent: "test"}, nil
}

methodInfo := newMethodInfo(testHandler, missingParamsOK)

// These should fail normally, not be treated as nil params
invalidCases := []json.RawMessage{
json.RawMessage(""), // empty string - should error
json.RawMessage("[]"), // array - should error
json.RawMessage(`"null"`), // string "null" - should error
json.RawMessage("0"), // number - should error
json.RawMessage("false"), // boolean - should error
}

for i, rawMsg := range invalidCases {
t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) {
params, err := methodInfo.unmarshalParams(rawMsg)
if err == nil && params == nil {
t.Error("Should not return nil params without error")
}
})
}

// Test that methods without missingParamsOK flag properly reject nil params
t.Run("reject_when_params_required", func(t *testing.T) {
methodInfoStrict := newMethodInfo(testHandler, 0) // No missingParamsOK flag

testCases := []struct {
name string
params json.RawMessage
}{
{"nil_params", nil},
{"null_params", json.RawMessage(`null`)},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := methodInfoStrict.unmarshalParams(tc.params)
if err == nil {
t.Error("Expected error for required params, got nil")
}
if !strings.Contains(err.Error(), "missing required \"params\"") {
t.Errorf("Expected 'missing required params' error, got: %v", err)
}
})
}
})
}
Loading