Skip to content

Commit de829ee

Browse files
authored
Add configurable strict mode for MCP client initialization (#43)
* Add configurable strict mode for MCP client initialization * Rename checkCapability to validateServerCapability
1 parent 5551f66 commit de829ee

File tree

2 files changed

+154
-16
lines changed

2 files changed

+154
-16
lines changed

Sources/MCP/Client/Client.swift

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,29 @@ import class Foundation.JSONEncoder
77

88
/// Model Context Protocol client
99
public actor Client {
10+
/// The client configuration
11+
public struct Configuration: Hashable, Codable, Sendable {
12+
/// The default configuration.
13+
public static let `default` = Configuration(strict: false)
14+
15+
/// The strict configuration.
16+
public static let strict = Configuration(strict: true)
17+
18+
/// When strict mode is enabled, the client:
19+
/// - Requires server capabilities to be initialized before making requests
20+
/// - Rejects all requests that require capabilities before initialization
21+
///
22+
/// While the MCP specification requires servers to respond to initialize requests
23+
/// with their capabilities, some implementations may not follow this.
24+
/// Disabling strict mode allows the client to be more lenient with non-compliant
25+
/// servers, though this may lead to undefined behavior.
26+
public var strict: Bool
27+
28+
public init(strict: Bool = false) {
29+
self.strict = strict
30+
}
31+
}
32+
1033
/// Implementation information
1134
public struct Info: Hashable, Codable, Sendable {
1235
/// The client name
@@ -73,6 +96,8 @@ public actor Client {
7396

7497
/// The client capabilities
7598
public var capabilities: Client.Capabilities
99+
/// The client configuration
100+
public var configuration: Configuration
76101

77102
/// The server capabilities
78103
private var serverCapabilities: Server.Capabilities?
@@ -131,10 +156,12 @@ public actor Client {
131156

132157
public init(
133158
name: String,
134-
version: String
159+
version: String,
160+
configuration: Configuration = .default
135161
) {
136162
self.clientInfo = Client.Info(name: name, version: version)
137163
self.capabilities = Capabilities()
164+
self.configuration = configuration
138165
}
139166

140167
/// Connect to the server using the given transport
@@ -294,7 +321,7 @@ public actor Client {
294321
public func getPrompt(name: String, arguments: [String: Value]? = nil) async throws
295322
-> (description: String?, messages: [Prompt.Message])
296323
{
297-
_ = try checkCapability(\.prompts, "Prompts")
324+
try validateServerCapability(\.prompts, "Prompts")
298325
let request = GetPrompt.request(.init(name: name, arguments: arguments))
299326
let result = try await send(request)
300327
return (description: result.description, messages: result.messages)
@@ -303,7 +330,7 @@ public actor Client {
303330
public func listPrompts(cursor: String? = nil) async throws
304331
-> (prompts: [Prompt], nextCursor: String?)
305332
{
306-
_ = try checkCapability(\.prompts, "Prompts")
333+
try validateServerCapability(\.prompts, "Prompts")
307334
let request: Request<ListPrompts>
308335
if let cursor = cursor {
309336
request = ListPrompts.request(.init(cursor: cursor))
@@ -317,7 +344,7 @@ public actor Client {
317344
// MARK: - Resources
318345

319346
public func readResource(uri: String) async throws -> [Resource.Content] {
320-
_ = try checkCapability(\.resources, "Resources")
347+
try validateServerCapability(\.resources, "Resources")
321348
let request = ReadResource.request(.init(uri: uri))
322349
let result = try await send(request)
323350
return result.contents
@@ -326,7 +353,7 @@ public actor Client {
326353
public func listResources(cursor: String? = nil) async throws -> (
327354
resources: [Resource], nextCursor: String?
328355
) {
329-
_ = try checkCapability(\.resources, "Resources")
356+
try validateServerCapability(\.resources, "Resources")
330357
let request: Request<ListResources>
331358
if let cursor = cursor {
332359
request = ListResources.request(.init(cursor: cursor))
@@ -338,15 +365,15 @@ public actor Client {
338365
}
339366

340367
public func subscribeToResource(uri: String) async throws {
341-
_ = try checkCapability(\.resources?.subscribe, "Resource subscription")
368+
try validateServerCapability(\.resources?.subscribe, "Resource subscription")
342369
let request = ResourceSubscribe.request(.init(uri: uri))
343370
_ = try await send(request)
344371
}
345372

346373
// MARK: - Tools
347374

348375
public func listTools(cursor: String? = nil) async throws -> [Tool] {
349-
_ = try checkCapability(\.tools, "Tools")
376+
try validateServerCapability(\.tools, "Tools")
350377
let request: Request<ListTools>
351378
if let cursor = cursor {
352379
request = ListTools.request(.init(cursor: cursor))
@@ -360,7 +387,7 @@ public actor Client {
360387
public func callTool(name: String, arguments: [String: Value]? = nil) async throws -> (
361388
content: [Tool.Content], isError: Bool?
362389
) {
363-
_ = try checkCapability(\.tools, "Tools")
390+
try validateServerCapability(\.tools, "Tools")
364391
let request = CallTool.request(.init(name: name, arguments: arguments))
365392
let result = try await send(request)
366393
return (content: result.content, isError: result.isError)
@@ -410,15 +437,21 @@ public actor Client {
410437

411438
// MARK: -
412439

413-
private func checkCapability<T>(_ keyPath: KeyPath<Server.Capabilities, T?>, _ name: String)
414-
throws -> T
440+
/// Validate the server capabilities.
441+
/// Throws an error if the client is configured to be strict and the capability is not supported.
442+
private func validateServerCapability<T>(
443+
_ keyPath: KeyPath<Server.Capabilities, T?>,
444+
_ name: String
445+
)
446+
throws
415447
{
416-
guard let capabilities = serverCapabilities else {
417-
throw Error.methodNotFound("Server capabilities not initialized")
418-
}
419-
guard let value = capabilities[keyPath: keyPath] else {
420-
throw Error.methodNotFound("\(name) is not supported by the server")
448+
if configuration.strict {
449+
guard let capabilities = serverCapabilities else {
450+
throw Error.methodNotFound("Server capabilities not initialized")
451+
}
452+
guard capabilities[keyPath: keyPath] != nil else {
453+
throw Error.methodNotFound("\(name) is not supported by the server")
454+
}
421455
}
422-
return value
423456
}
424457
}

Tests/MCPTests/ClientTests.swift

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import Foundation
12
import Testing
23

34
@testable import MCP
@@ -121,4 +122,108 @@ struct ClientTests {
121122
#expect(Bool(false), "Expected MCP.Error")
122123
}
123124
}
125+
126+
@Test("Strict configuration - capabilities check")
127+
func testStrictConfiguration() async throws {
128+
let transport = MockTransport()
129+
let config = Client.Configuration.strict
130+
let client = Client(name: "TestClient", version: "1.0", configuration: config)
131+
132+
try await client.connect(transport: transport)
133+
134+
// Create a task for listPrompts
135+
let promptsTask = Task<Void, Swift.Error> {
136+
do {
137+
_ = try await client.listPrompts()
138+
#expect(Bool(false), "Expected listPrompts to fail in strict mode")
139+
} catch let error as Error {
140+
if case Error.methodNotFound = error {
141+
#expect(Bool(true))
142+
} else {
143+
#expect(Bool(false), "Expected methodNotFound error, got \(error)")
144+
}
145+
} catch {
146+
#expect(Bool(false), "Expected MCP.Error")
147+
}
148+
}
149+
150+
// Give it a short time to execute the task
151+
try await Task.sleep(for: .milliseconds(50))
152+
153+
// Cancel the task if it's still running
154+
promptsTask.cancel()
155+
156+
// Disconnect client
157+
await client.disconnect()
158+
try await Task.sleep(for: .milliseconds(50))
159+
}
160+
161+
@Test("Non-strict configuration - capabilities check")
162+
func testNonStrictConfiguration() async throws {
163+
let transport = MockTransport()
164+
let config = Client.Configuration.default
165+
let client = Client(name: "TestClient", version: "1.0", configuration: config)
166+
167+
try await client.connect(transport: transport)
168+
169+
// Wait a bit for any setup to complete
170+
try await Task.sleep(for: .milliseconds(10))
171+
172+
// Send the listPrompts request and immediately provide an error response
173+
let promptsTask = Task {
174+
do {
175+
// Start the request
176+
try await Task.sleep(for: .seconds(1))
177+
178+
// Get the last sent message and extract the request ID
179+
if let lastMessage = await transport.sentMessages.last,
180+
let data = lastMessage.data(using: .utf8),
181+
let decodedRequest = try? JSONDecoder().decode(
182+
Request<ListPrompts>.self, from: data)
183+
{
184+
185+
// Create an error response with the same ID
186+
let errorResponse = Response<ListPrompts>(
187+
id: decodedRequest.id,
188+
error: Error.methodNotFound("Test: Prompts capability not available")
189+
)
190+
try await transport.queueResponse(errorResponse)
191+
192+
// Try the request now that we have a response queued
193+
do {
194+
_ = try await client.listPrompts()
195+
#expect(Bool(false), "Expected listPrompts to fail in non-strict mode")
196+
} catch let error as Error {
197+
if case Error.methodNotFound = error {
198+
#expect(Bool(true))
199+
} else {
200+
#expect(Bool(false), "Expected methodNotFound error, got \(error)")
201+
}
202+
} catch {
203+
#expect(Bool(false), "Expected MCP.Error")
204+
}
205+
}
206+
} catch {
207+
// Ignore task cancellation
208+
if !(error is CancellationError) {
209+
throw error
210+
}
211+
}
212+
}
213+
214+
// Wait for the task to complete or timeout
215+
let timeoutTask = Task {
216+
try await Task.sleep(for: .milliseconds(500))
217+
promptsTask.cancel()
218+
}
219+
220+
// Wait for the task to complete
221+
_ = await promptsTask.result
222+
223+
// Cancel the timeout task
224+
timeoutTask.cancel()
225+
226+
// Disconnect client
227+
await client.disconnect()
228+
}
124229
}

0 commit comments

Comments
 (0)