diff --git a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift index fbd84d16..7aa32e79 100644 --- a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift +++ b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift @@ -8,14 +8,28 @@ import MLXVLM enum IntegrationTestModelIDs { static let llmModelId = "mlx-community/Qwen3-4B-Instruct-2507-4bit" static let vlmModelId = "mlx-community/Qwen3-VL-4B-Instruct-4bit" + + static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" + static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" + static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" + static let nemotronModelId = "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-4bit" + static let qwen35ModelId = "mlx-community/Qwen3.5-2B-4bit" } actor IntegrationTestModels { static let shared = IntegrationTestModels() + private init() {} + private var llmTask: Task<ModelContainer, Error>? private var vlmTask: Task<ModelContainer, Error>? + private var lfm2Task: Task<ModelContainer, Error>? + private var glm4Task: Task<ModelContainer, Error>? + private var mistral3Task: Task<ModelContainer, Error>? + private var nemotronTask: Task<ModelContainer, Error>? + private var qwen35Task: Task<ModelContainer, Error>? 
+ func llmContainer() async throws -> ModelContainer { if let task = llmTask { return try await task.value @@ -43,4 +57,74 @@ actor IntegrationTestModels { vlmTask = task return try await task.value } + + func lfm2Container() async throws -> ModelContainer { + if let task = lfm2Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.lfm2ModelId) + ) + } + lfm2Task = task + return try await task.value + } + + func glm4Container() async throws -> ModelContainer { + if let task = glm4Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.glm4ModelId) + ) + } + glm4Task = task + return try await task.value + } + + func mistral3Container() async throws -> ModelContainer { + if let task = mistral3Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.mistral3ModelId) + ) + } + mistral3Task = task + return try await task.value + } + + func nemotronContainer() async throws -> ModelContainer { + if let task = nemotronTask { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.nemotronModelId) + ) + } + nemotronTask = task + return try await task.value + } + + func qwen35Container() async throws -> ModelContainer { + if let task = qwen35Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.qwen35ModelId) + ) + } + qwen35Task = task + return try await task.value + } } diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index 3f40d988..4d04b6f8 100644 --- 
a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -18,18 +18,6 @@ import XCTest /// - GLM4: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/glm47.py public class ToolCallIntegrationTests: XCTestCase { - // MARK: - Model IDs - - static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" - static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" - static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" - - // MARK: - Shared State - - nonisolated(unsafe) static var lfm2Container: ModelContainer? - nonisolated(unsafe) static var glm4Container: ModelContainer? - nonisolated(unsafe) static var mistral3Container: ModelContainer? - // MARK: - Tool Schema static let weatherToolSchema: [[String: any Sendable]] = [ @@ -57,60 +45,62 @@ public class ToolCallIntegrationTests: XCTestCase { ] ] - // MARK: - Setup + // MARK: - Model Loading - override public class func setUp() { - super.setUp() - - let lfm2Expectation = XCTestExpectation(description: "Load LFM2") - let glm4Expectation = XCTestExpectation(description: "Load GLM4") - let mistral3Expectation = XCTestExpectation(description: "Load Mistral3") + private var lfm2Container: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.lfm2Container() + } catch { + throw XCTSkip("LFM2 model not available: \(error)") + } + } + } - Task { + private var glm4Container: ModelContainer { + get async throws { do { - lfm2Container = try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: lfm2ModelId) - ) + return try await IntegrationTestModels.shared.glm4Container() } catch { - print("Failed to load LFM2: \(error)") + throw XCTSkip("GLM4 model not available: \(error)") } - lfm2Expectation.fulfill() } + } - Task { + private var mistral3Container: ModelContainer { + get async throws { do { - glm4Container = try await LLMModelFactory.shared.loadContainer( - 
configuration: .init(id: glm4ModelId) - ) + return try await IntegrationTestModels.shared.mistral3Container() } catch { - print("Failed to load GLM4: \(error)") + throw XCTSkip("Mistral3 model not available: \(error)") } - glm4Expectation.fulfill() } + } - Task { + private var nemotronContainer: ModelContainer { + get async throws { do { - mistral3Container = try await VLMModelFactory.shared.loadContainer( - configuration: .init(id: mistral3ModelId) - ) + return try await IntegrationTestModels.shared.nemotronContainer() } catch { - print("Failed to load Mistral3: \(error)") + throw XCTSkip("Nemotron model not available: \(error)") } - mistral3Expectation.fulfill() } + } - _ = XCTWaiter.wait( - for: [lfm2Expectation, glm4Expectation, mistral3Expectation], timeout: 600) + private var qwen35Container: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.qwen35Container() + } catch { + throw XCTSkip("Qwen3.5 model not available: \(error)") + } + } } // MARK: - LFM2 Tests func testLFM2ToolCallFormatAutoDetection() async throws { - guard let container = Self.lfm2Container else { - throw XCTSkip("LFM2 model not available") - } - - let config = await container.configuration + let config = try await lfm2Container.configuration XCTAssertEqual( config.toolCallFormat, .lfm2, "LFM2 model should auto-detect .lfm2 tool call format" @@ -118,9 +108,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testLFM2EndToEndToolCallGeneration() async throws { - guard let container = Self.lfm2Container else { - throw XCTSkip("LFM2 model not available") - } + let container = try await lfm2Container // Create input with tool schema let input = UserInput( @@ -160,11 +148,7 @@ public class ToolCallIntegrationTests: XCTestCase { // MARK: - GLM4 Tests func testGLM4ToolCallFormatAutoDetection() async throws { - guard let container = Self.glm4Container else { - throw XCTSkip("GLM4 model not available") - } - - let config = await 
container.configuration + let config = try await glm4Container.configuration XCTAssertEqual( config.toolCallFormat, .glm4, "GLM4 model should auto-detect .glm4 tool call format" @@ -172,9 +156,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testGLM4EndToEndToolCallGeneration() async throws { - guard let container = Self.glm4Container else { - throw XCTSkip("GLM4 model not available") - } + let container = try await glm4Container // Create input with tool schema let input = UserInput( @@ -214,11 +196,7 @@ public class ToolCallIntegrationTests: XCTestCase { // MARK: - Mistral3 Tests func testMistral3ToolCallFormatAutoDetection() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } - - let config = await container.configuration + let config = try await mistral3Container.configuration XCTAssertEqual( config.toolCallFormat, .mistral, "Mistral3 model should auto-detect .mistral tool call format" @@ -226,9 +204,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testMistral3EndToEndToolCallGeneration() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } + let container = try await mistral3Container let input = UserInput( chat: [ @@ -263,9 +239,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testMistral3MultipleToolCallGeneration() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } + let container = try await mistral3Container let multiToolSchema: [[String: any Sendable]] = Self.weatherToolSchema + [ @@ -325,6 +299,215 @@ public class ToolCallIntegrationTests: XCTestCase { } } + // MARK: - Nemotron Tests + + func testNemotronToolCallFormatAutoDetection() async throws { + let config = try await nemotronContainer.configuration + XCTAssertEqual( + config.toolCallFormat, .xmlFunction, + "Nemotron model should auto-detect 
.xmlFunction tool call format" + ) + } + + func testNemotronEndToEndToolCallGeneration() async throws { + let container = try await nemotronContainer + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." + ), + .user("What's the weather in Tokyo?"), + ], + tools: Self.weatherToolSchema, + additionalContext: ["enable_thinking": false] + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 150 + ) + + print("Nemotron Output: \(result)") + print("Nemotron Tool Calls: \(toolCalls)") + + if !toolCalls.isEmpty { + let toolCall = toolCalls.first! + XCTAssertEqual(toolCall.function.name, "get_weather") + if let location = toolCall.function.arguments["location"]?.asString { + XCTAssertTrue( + location.lowercased().contains("tokyo"), + "Expected location to contain 'Tokyo', got: \(location)" + ) + } + } + } + + func testNemotronMultipleToolCallGeneration() async throws { + let container = try await nemotronContainer + + let multiToolSchema: [[String: any Sendable]] = + Self.weatherToolSchema + [ + [ + "type": "function", + "function": [ + "name": "get_time", + "description": "Get the current time in a given timezone", + "parameters": [ + "type": "object", + "properties": [ + "timezone": [ + "type": "string", + "description": + "The timezone, e.g. America/New_York, Asia/Tokyo", + ] as [String: any Sendable] + ] as [String: any Sendable], + "required": ["timezone"], + ] as [String: any Sendable], + ] as [String: any Sendable], + ] + ] + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed." + ), + .user( + "What's the weather in Tokyo and what time is it there?" 
+ ), + ], + tools: multiToolSchema, + additionalContext: ["enable_thinking": false] + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 600 + ) + + print("Nemotron Output: \(result)") + print("Nemotron Calls: \(toolCalls)") + + let validNames: Set = ["get_weather", "get_time"] + for toolCall in toolCalls { + XCTAssertTrue( + validNames.contains(toolCall.function.name), + "Unexpected tool call: \(toolCall.function.name)" + ) + } + + if toolCalls.count > 1 { + print("Successfully parsed \(toolCalls.count) tool calls from Nemotron") + } + } + + // MARK: - Qwen3.5 Tests + + func testQwen35ToolCallFormatAutoDetection() async throws { + let config = try await qwen35Container.configuration + XCTAssertEqual( + config.toolCallFormat, .xmlFunction, + "Qwen3.5 model should auto-detect .xmlFunction tool call format" + ) + } + + func testQwen35EndToEndToolCallGeneration() async throws { + let container = try await qwen35Container + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." + ), + .user("What's the weather in Tokyo?"), + ], + tools: Self.weatherToolSchema + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 150 + ) + + print("Qwen3.5 Output: \(result)") + print("Qwen3.5 Tool Calls: \(toolCalls)") + + if !toolCalls.isEmpty { + let toolCall = toolCalls.first! 
+ XCTAssertEqual(toolCall.function.name, "get_weather") + if let location = toolCall.function.arguments["location"]?.asString { + XCTAssertTrue( + location.lowercased().contains("tokyo"), + "Expected location to contain 'Tokyo', got: \(location)" + ) + } + } + } + + func testQwen35MultipleToolCallGeneration() async throws { + let container = try await qwen35Container + + let multiToolSchema: [[String: any Sendable]] = + Self.weatherToolSchema + [ + [ + "type": "function", + "function": [ + "name": "get_time", + "description": "Get the current time in a given timezone", + "parameters": [ + "type": "object", + "properties": [ + "timezone": [ + "type": "string", + "description": + "The timezone, e.g. America/New_York, Asia/Tokyo", + ] as [String: any Sendable] + ] as [String: any Sendable], + "required": ["timezone"], + ] as [String: any Sendable], + ] as [String: any Sendable], + ] + ] + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed." + ), + .user( + "What's the weather in Tokyo and what time is it there?" + ), + ], + tools: multiToolSchema, + additionalContext: ["enable_thinking": true] + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 300 + ) + + print("Qwen3.5 Output: \(result)") + print("Qwen3.5 Calls: \(toolCalls)") + + let validNames: Set = ["get_weather", "get_time"] + for toolCall in toolCalls { + XCTAssertTrue( + validNames.contains(toolCall.function.name), + "Unexpected tool call: \(toolCall.function.name)" + ) + } + + if toolCalls.count > 1 { + print("Successfully parsed \(toolCalls.count) tool calls from Qwen3.5") + } + } + // MARK: - Helper Methods /// Generate text and collect any tool calls