Skip to content

Commit 75322c3

Browse files
authored
Ensure Request decoder handles missing/null params correctly (#22)
1 parent 2510dda commit 75322c3

File tree

2 files changed

+177
-18
lines changed

2 files changed

+177
-18
lines changed

Sources/MCP/Base/Messages.swift

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,31 @@ extension Request {
101101
method = try container.decode(String.self, forKey: .method)
102102

103103
if M.Parameters.self is NotRequired.Type {
104+
// For NotRequired parameters, use decodeIfPresent or init()
104105
params =
105106
(try container.decodeIfPresent(M.Parameters.self, forKey: .params)
106107
?? (M.Parameters.self as! NotRequired.Type).init() as! M.Parameters)
108+
} else if let value = try? container.decode(M.Parameters.self, forKey: .params) {
109+
// If params exists and can be decoded, use it
110+
params = value
111+
} else if !container.contains(.params)
112+
|| (try? container.decodeNil(forKey: .params)) == true
113+
{
114+
// If params is missing or explicitly null, use Empty for Empty parameters
115+
// or throw for non-Empty parameters
116+
if M.Parameters.self == Empty.self {
117+
params = Empty() as! M.Parameters
118+
} else {
119+
throw DecodingError.dataCorrupted(
120+
DecodingError.Context(
121+
codingPath: container.codingPath,
122+
debugDescription: "Missing required params field"))
123+
}
107124
} else {
108-
params = try container.decode(M.Parameters.self, forKey: .params)
125+
throw DecodingError.dataCorrupted(
126+
DecodingError.Context(
127+
codingPath: container.codingPath,
128+
debugDescription: "Invalid params field"))
109129
}
110130
}
111131
}

Tests/MCPTests/RequestTests.swift

Lines changed: 156 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,33 @@ struct RequestTests {
2323

2424
@Test("Request initialization with parameters")
2525
func testRequestInitialization() throws {
26-
let id: ID = "test-id"
27-
let params = TestMethod.Parameters(value: "test")
28-
let request = Request<TestMethod>(id: id, method: TestMethod.name, params: params)
26+
let id: ID = 1
27+
let params = CallTool.Parameters(name: "test-tool")
28+
let request = Request<CallTool>(id: id, method: CallTool.name, params: params)
2929

3030
#expect(request.id == id)
31-
#expect(request.method == TestMethod.name)
32-
#expect(request.params.value == "test")
31+
#expect(request.method == CallTool.name)
32+
#expect(request.params.name == "test-tool")
3333
}
3434

3535
@Test("Request encoding and decoding")
3636
func testRequestEncodingDecoding() throws {
37-
let request = TestMethod.request(id: "test-id", TestMethod.Parameters(value: "test"))
37+
let request = CallTool.request(id: 1, CallTool.Parameters(name: "test-tool"))
3838

3939
let encoder = JSONEncoder()
4040
let decoder = JSONDecoder()
4141

4242
let data = try encoder.encode(request)
43-
let decoded = try decoder.decode(Request<TestMethod>.self, from: data)
43+
let decoded = try decoder.decode(Request<CallTool>.self, from: data)
4444

4545
#expect(decoded.id == request.id)
4646
#expect(decoded.method == request.method)
47-
#expect(decoded.params.value == request.params.value)
47+
#expect(decoded.params.name == request.params.name)
4848
}
4949

5050
@Test("Empty parameters request encoding")
5151
func testEmptyParametersRequestEncoding() throws {
52-
let request = EmptyMethod.request(id: "test-id")
52+
let request = EmptyMethod.request(id: 1)
5353

5454
let encoder = JSONEncoder()
5555
let decoder = JSONDecoder()
@@ -66,59 +66,198 @@ struct RequestTests {
6666
func testEmptyParametersRequestDecoding() throws {
6767
// Create a minimal JSON string
6868
let jsonString = """
69-
{"jsonrpc":"2.0","id":"test-id","method":"empty.method"}
69+
{"jsonrpc":"2.0","id":1,"method":"empty.method"}
7070
"""
7171
let data = jsonString.data(using: .utf8)!
7272

7373
let decoder = JSONDecoder()
7474
let decoded = try decoder.decode(Request<EmptyMethod>.self, from: data)
7575

76-
#expect(decoded.id == "test-id")
76+
#expect(decoded.id == 1)
7777
#expect(decoded.method == EmptyMethod.name)
7878
}
7979

8080
@Test("NotRequired parameters request decoding - with params")
8181
func testNotRequiredParametersRequestDecodingWithParams() throws {
8282
// Test decoding when params field is present
8383
let jsonString = """
84-
{"jsonrpc":"2.0","id":"test-id","method":"ping","params":{}}
84+
{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}
8585
"""
8686
let data = jsonString.data(using: .utf8)!
8787

8888
let decoder = JSONDecoder()
8989
let decoded = try decoder.decode(Request<Ping>.self, from: data)
9090

91-
#expect(decoded.id == "test-id")
91+
#expect(decoded.id == 1)
9292
#expect(decoded.method == Ping.name)
9393
}
9494

9595
@Test("NotRequired parameters request decoding - without params")
9696
func testNotRequiredParametersRequestDecodingWithoutParams() throws {
9797
// Test decoding when params field is missing
9898
let jsonString = """
99-
{"jsonrpc":"2.0","id":"test-id","method":"ping"}
99+
{"jsonrpc":"2.0","id":1,"method":"ping"}
100100
"""
101101
let data = jsonString.data(using: .utf8)!
102102

103103
let decoder = JSONDecoder()
104104
let decoded = try decoder.decode(Request<Ping>.self, from: data)
105105

106-
#expect(decoded.id == "test-id")
106+
#expect(decoded.id == 1)
107107
#expect(decoded.method == Ping.name)
108108
}
109109

110110
@Test("NotRequired parameters request decoding - with null params")
111111
func testNotRequiredParametersRequestDecodingWithNullParams() throws {
112112
// Test decoding when params field is null
113113
let jsonString = """
114-
{"jsonrpc":"2.0","id":"test-id","method":"ping","params":null}
114+
{"jsonrpc":"2.0","id":1,"method":"ping","params":null}
115115
"""
116116
let data = jsonString.data(using: .utf8)!
117117

118118
let decoder = JSONDecoder()
119119
let decoded = try decoder.decode(Request<Ping>.self, from: data)
120120

121-
#expect(decoded.id == "test-id")
121+
#expect(decoded.id == 1)
122122
#expect(decoded.method == Ping.name)
123123
}
124+
125+
@Test("Required parameters request decoding - missing params")
126+
func testRequiredParametersRequestDecodingMissingParams() throws {
127+
let jsonString = """
128+
{"jsonrpc":"2.0","id":1,"method":"tools/call"}
129+
"""
130+
let data = jsonString.data(using: .utf8)!
131+
132+
let decoder = JSONDecoder()
133+
#expect(throws: DecodingError.self) {
134+
_ = try decoder.decode(Request<CallTool>.self, from: data)
135+
}
136+
}
137+
138+
@Test("Required parameters request decoding - null params")
139+
func testRequiredParametersRequestDecodingNullParams() throws {
140+
let jsonString = """
141+
{"jsonrpc":"2.0","id":1,"method":"tools/call","params":null}
142+
"""
143+
let data = jsonString.data(using: .utf8)!
144+
145+
let decoder = JSONDecoder()
146+
#expect(throws: DecodingError.self) {
147+
_ = try decoder.decode(Request<CallTool>.self, from: data)
148+
}
149+
}
150+
151+
@Test("Empty parameters request decoding - with null params")
152+
func testEmptyParametersRequestDecodingNullParams() throws {
153+
let jsonString = """
154+
{"jsonrpc":"2.0","id":1,"method":"empty.method","params":null}
155+
"""
156+
let data = jsonString.data(using: .utf8)!
157+
158+
let decoder = JSONDecoder()
159+
let decoded = try decoder.decode(Request<EmptyMethod>.self, from: data)
160+
161+
#expect(decoded.id == 1)
162+
#expect(decoded.method == EmptyMethod.name)
163+
}
164+
165+
@Test("Empty parameters request decoding - with empty object params")
166+
func testEmptyParametersRequestDecodingEmptyParams() throws {
167+
let jsonString = """
168+
{"jsonrpc":"2.0","id":1,"method":"empty.method","params":{}}
169+
"""
170+
let data = jsonString.data(using: .utf8)!
171+
172+
let decoder = JSONDecoder()
173+
let decoded = try decoder.decode(Request<EmptyMethod>.self, from: data)
174+
175+
#expect(decoded.id == 1)
176+
#expect(decoded.method == EmptyMethod.name)
177+
}
178+
179+
@Test("Initialize request decoding - requires params")
180+
func testInitializeRequestDecodingRequiresParams() throws {
181+
// Test missing params field
182+
let missingParams = """
183+
{"jsonrpc":"2.0","id":"test-id","method":"initialize"}
184+
"""
185+
let decoder = JSONDecoder()
186+
#expect(throws: DecodingError.self) {
187+
_ = try decoder.decode(
188+
Request<Initialize>.self, from: missingParams.data(using: .utf8)!)
189+
}
190+
191+
// Test null params
192+
let nullParams = """
193+
{"jsonrpc":"2.0","id":"test-id","method":"initialize","params":null}
194+
"""
195+
#expect(throws: DecodingError.self) {
196+
_ = try decoder.decode(Request<Initialize>.self, from: nullParams.data(using: .utf8)!)
197+
}
198+
199+
// Verify that empty object params works (since fields have defaults)
200+
let emptyParams = """
201+
{"jsonrpc":"2.0","id":"test-id","method":"initialize","params":{}}
202+
"""
203+
let decoded = try decoder.decode(
204+
Request<Initialize>.self, from: emptyParams.data(using: .utf8)!)
205+
#expect(decoded.params.protocolVersion == Version.latest)
206+
#expect(decoded.params.clientInfo.name == "unknown")
207+
}
208+
209+
@Test("Invalid parameters request decoding")
210+
func testInvalidParametersRequestDecoding() throws {
211+
let jsonString = """
212+
{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"invalid":"value"}}
213+
"""
214+
let data = jsonString.data(using: .utf8)!
215+
216+
let decoder = JSONDecoder()
217+
#expect(throws: DecodingError.self) {
218+
_ = try decoder.decode(Request<CallTool>.self, from: data)
219+
}
220+
}
221+
222+
@Test("NotRequired parameters request decoding")
223+
func testNotRequiredParametersRequestDecoding() throws {
224+
// Test with missing params
225+
let missingParams = """
226+
{"jsonrpc":"2.0","id":1,"method":"tools/list"}
227+
"""
228+
let decoder = JSONDecoder()
229+
let decodedMissing = try decoder.decode(
230+
Request<ListTools>.self,
231+
from: missingParams.data(using: .utf8)!)
232+
#expect(decodedMissing.id == 1)
233+
#expect(decodedMissing.method == ListTools.name)
234+
#expect(decodedMissing.params.cursor == nil)
235+
236+
// Test with null params
237+
let nullParams = """
238+
{"jsonrpc":"2.0","id":1,"method":"tools/list","params":null}
239+
"""
240+
let decodedNull = try decoder.decode(
241+
Request<ListTools>.self,
242+
from: nullParams.data(using: .utf8)!)
243+
#expect(decodedNull.params.cursor == nil)
244+
245+
// Test with empty object params
246+
let emptyParams = """
247+
{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}
248+
"""
249+
let decodedEmpty = try decoder.decode(
250+
Request<ListTools>.self,
251+
from: emptyParams.data(using: .utf8)!)
252+
#expect(decodedEmpty.params.cursor == nil)
253+
254+
// Test with provided cursor
255+
let withCursor = """
256+
{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{"cursor":"next-page"}}
257+
"""
258+
let decodedWithCursor = try decoder.decode(
259+
Request<ListTools>.self,
260+
from: withCursor.data(using: .utf8)!)
261+
#expect(decodedWithCursor.params.cursor == "next-page")
262+
}
124263
}

0 commit comments

Comments
 (0)