Skip to content

Commit f73336f

Browse files
committed
Fix tool calling for Llama 3
Support multiple parallel tool calls and buffering for Llama 3 Llama 3 natively supports tool calling through an ipython environment which generates arrays for multiple parallel tool invocations. Depending on the model size and prompt, it generates either a JSON list of function objects or a python-style array of function calls. - Sets `startTag` to `<|python_tag|>` to ensure `ToolCallProcessor` correctly buffers tool output without leaking it to the streaming UI. - Upgrades `Llama3ToolCallParser` to parse multiple parallel tool calls from JSON array payloads `[{"name": ...}]` during `parseEOS`. - Upgrades `PythonicToolCallParser` to extract multiple sequential pythonic function calls `[func1(), func2()]` via `parseEOS`. - Refactors `PythonicToolCallParser` to use modern high-performance Swift 5.7+ Regex literals instead of legacy NSRegularExpression. - Add integration unit tests for both parsers to verify multi-call arrays.
1 parent f7a235d commit f73336f

File tree

6 files changed

+346
-24
lines changed

6 files changed

+346
-24
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,8 @@ public final class LLMModelFactory: ModelFactory {
534534

535535
// Auto-detect tool call format from model type if not explicitly set
536536
if mutableConfiguration.toolCallFormat == nil {
537-
mutableConfiguration.toolCallFormat = ToolCallFormat.infer(from: baseConfig.modelType)
537+
mutableConfiguration.toolCallFormat = ToolCallFormat.infer(
538+
from: baseConfig.modelType, configData: configData)
538539
}
539540

540541
// Load tokenizer and weights in parallel using async let.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
5+
/// Parser for Llama 3 tool calls.
6+
/// Llama 3 often outputs inline JSON without standard start/end tags, or preceded by `<|python_tag|>`.
7+
/// It may also output native python function calls like `get_weather(location="San Francisco")`.
8+
public struct Llama3ToolCallParser: ToolCallParser, Sendable {
9+
public let startTag: String? = nil
10+
public let endTag: String? = nil
11+
12+
public init() {}
13+
14+
private struct LlamaFunction: Codable {
15+
let name: String
16+
let parameters: [String: JSONValue]?
17+
let arguments: [String: JSONValue]?
18+
}
19+
20+
public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
21+
var text = content
22+
23+
// If it outputs python tag, strip it
24+
if let range = text.range(of: "<|python_tag|>") {
25+
text = String(text[range.upperBound...])
26+
}
27+
28+
let jsonStr = text.trimmingCharacters(in: .whitespacesAndNewlines)
29+
30+
// Try JSON format first
31+
if let data = jsonStr.data(using: .utf8),
32+
let llamaFunc = try? JSONDecoder().decode(LlamaFunction.self, from: data)
33+
{
34+
let args = llamaFunc.parameters ?? llamaFunc.arguments ?? [:]
35+
36+
let function = ToolCall.Function(
37+
name: llamaFunc.name,
38+
arguments: args.mapValues { $0.anyValue }
39+
)
40+
return ToolCall(function: function)
41+
}
42+
43+
// Fallback to Pythonic format
44+
let pythonicParser = PythonicToolCallParser()
45+
return pythonicParser.parse(content: jsonStr, tools: tools)
46+
}
47+
48+
public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
49+
var text = toolCallBuffer
50+
51+
// If it outputs python tag, strip it
52+
if let range = text.range(of: "<|python_tag|>") {
53+
text = String(text[range.upperBound...])
54+
}
55+
56+
let jsonStr = text.trimmingCharacters(in: .whitespacesAndNewlines)
57+
58+
guard let data = jsonStr.data(using: .utf8) else {
59+
return []
60+
}
61+
62+
// Try JSON list format
63+
if let list = try? JSONDecoder().decode([LlamaFunction].self, from: data) {
64+
return list.map { llamaFunc in
65+
let args = llamaFunc.parameters ?? llamaFunc.arguments ?? [:]
66+
let function = ToolCall.Function(
67+
name: llamaFunc.name,
68+
arguments: args.mapValues { $0.anyValue }
69+
)
70+
return ToolCall(function: function)
71+
}
72+
}
73+
74+
// Try single JSON format
75+
if let llamaFunc = try? JSONDecoder().decode(LlamaFunction.self, from: data) {
76+
let args = llamaFunc.parameters ?? llamaFunc.arguments ?? [:]
77+
let function = ToolCall.Function(
78+
name: llamaFunc.name,
79+
arguments: args.mapValues { $0.anyValue }
80+
)
81+
return [ToolCall(function: function)]
82+
}
83+
84+
// Try Pythonic list like [func1(args), func2(args)] or single func1(args)
85+
let pythonicParser = PythonicToolCallParser()
86+
return pythonicParser.parseEOS(jsonStr, tools: tools)
87+
}
88+
}

Libraries/MLXLMCommon/Tool/Parsers/PythonicToolCallParser.swift

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,42 @@ public struct PythonicToolCallParser: ToolCallParser, Sendable {
6161
return ToolCall(function: .init(name: funcName, arguments: arguments))
6262
}
6363

64+
public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
65+
if let startTag {
66+
return
67+
toolCallBuffer
68+
.components(separatedBy: startTag)
69+
.filter { !$0.isEmpty }
70+
.flatMap { parseMultiple(content: $0, tools: tools) }
71+
} else {
72+
return parseMultiple(content: toolCallBuffer, tools: tools)
73+
}
74+
}
75+
76+
private func parseMultiple(content: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
77+
var text = content
78+
79+
if let end = endTag, let endRange = text.range(of: end) {
80+
text = String(text[..<endRange.lowerBound])
81+
}
82+
83+
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
84+
85+
let regex = #/(?s)(\w+)\((.*?)\)/#
86+
let matches = text.matches(of: regex)
87+
88+
var results: [ToolCall] = []
89+
for match in matches {
90+
let funcName = String(match.1)
91+
let argsString = String(match.2)
92+
let arguments = parseArguments(argsString, funcName: funcName, tools: tools)
93+
94+
results.append(ToolCall(function: .init(name: funcName, arguments: arguments)))
95+
}
96+
97+
return results
98+
}
99+
64100
/// Parse Pythonic keyword arguments: arg1='value1', arg2="value2", arg3=123
65101
private func parseArguments(
66102
_ argsString: String,
@@ -71,22 +107,12 @@ public struct PythonicToolCallParser: ToolCallParser, Sendable {
71107

72108
// Pattern for key=value pairs, handling quoted strings with possible commas inside
73109
// This handles: key='value', key="value", key=123, key=True, key=None
74-
let argPattern = #"(\w+)\s*=\s*('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*"|[^,\)]+)"#
75-
76-
guard let regex = try? NSRegularExpression(pattern: argPattern, options: []) else {
77-
return arguments
78-
}
79-
80-
let matches = regex.matches(
81-
in: argsString, options: [], range: NSRange(argsString.startIndex..., in: argsString))
110+
let argRegex = #/(\w+)\s*=\s*('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*"|[^,\)]+)/#
111+
let matches = argsString.matches(of: argRegex)
82112

83113
for match in matches {
84-
guard let keyRange = Range(match.range(at: 1), in: argsString),
85-
let valueRange = Range(match.range(at: 2), in: argsString)
86-
else { continue }
87-
88-
let key = String(argsString[keyRange])
89-
var value = String(argsString[valueRange]).trimmingCharacters(in: .whitespaces)
114+
let key = String(match.1)
115+
var value = String(match.2).trimmingCharacters(in: .whitespaces)
90116

91117
// Remove surrounding quotes if present
92118
if (value.hasPrefix("'") && value.hasSuffix("'"))

Libraries/MLXLMCommon/Tool/ToolCallFormat.swift

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
9494
/// Example: `[TOOL_CALLS]get_weather [ARGS]{"location": "Tokyo"}`
9595
case mistral
9696

97+
/// Llama 3 inline JSON format.
98+
/// Example: `<|python_tag|>{ "name": "func", "parameters": {...} }`
99+
case llama3
100+
97101
// MARK: - Factory Methods
98102

99103
/// Create the appropriate parser for this format.
@@ -117,6 +121,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
117121
return MiniMaxM2ToolCallParser()
118122
case .mistral:
119123
return MistralToolCallParser()
124+
case .llama3:
125+
return Llama3ToolCallParser()
120126
}
121127
}
122128

@@ -125,11 +131,35 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
125131
/// This method maps known model types to their corresponding tool call formats,
126132
/// enabling automatic format detection when loading models.
127133
///
128-
/// - Parameter modelType: The `model_type` value from config.json
134+
/// - Parameters:
135+
/// - modelType: The `model_type` value from config.json
136+
/// - configData: The raw config.json data for inspecting secondary signals (e.g. `rope_scaling` for Llama 3)
129137
/// - Returns: The appropriate `ToolCallFormat`, or `nil` to use the default format
130-
public static func infer(from modelType: String) -> ToolCallFormat? {
138+
public static func infer(from modelType: String, configData: Data? = nil) -> ToolCallFormat? {
131139
let type = modelType.lowercased()
132140

141+
// Llama family (need secondary signal for Llama 3 vs 1/2)
142+
if type == "llama" {
143+
guard let data = configData,
144+
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any]
145+
else { return nil }
146+
147+
// Secondary signal 1: vocab_size >= 128000 (Llama 3 uses 128256, Llama 2 uses 32000)
148+
if let vocabSize = json["vocab_size"] as? Int, vocabSize >= 128000 {
149+
return .llama3
150+
}
151+
152+
// Secondary signal 2: rope_scaling with rope_type == "llama3"
153+
if let ropeScaling = json["rope_scaling"] as? [String: Any],
154+
let ropeType = ropeScaling["rope_type"] as? String,
155+
ropeType == "llama3"
156+
{
157+
return .llama3
158+
}
159+
160+
return nil
161+
}
162+
133163
// LFM2 family (lfm2, lfm2_moe, lfm2_5, lfm25, etc.)
134164
if type.hasPrefix("lfm2") {
135165
return .lfm2

Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,72 @@ public class ToolCallProcessor {
101101
// MARK: - Private Methods
102102

103103
/// Process chunk for inline formats (no wrapper tags).
104+
///
105+
/// Uses brace counting to detect when output looks like a JSON tool call.
106+
/// While braces are unbalanced the content is buffered (returns `nil`)
107+
/// so partial JSON is never leaked to the UI.
104108
private func processInlineChunk(_ chunk: String) -> String? {
105-
toolCallBuffer += chunk
109+
switch state {
110+
case .normal:
111+
// Check if this chunk starts what looks like a JSON tool call
112+
if let braceIndex = chunk.firstIndex(of: "{") {
113+
let leading = String(chunk[..<braceIndex])
114+
let jsonPart = String(chunk[braceIndex...])
115+
toolCallBuffer = jsonPart
116+
state = .collectingToolCall
117+
118+
if let toolCall = parser.parse(content: toolCallBuffer, tools: tools) {
119+
toolCalls.append(toolCall)
120+
toolCallBuffer = ""
121+
state = .normal
122+
return leading.isEmpty ? nil : leading
123+
}
124+
125+
// Still collecting — check if braces are balanced (would mean parse
126+
// failed on complete JSON, so it's not a tool call)
127+
if jsonBracesBalanced(toolCallBuffer) {
128+
state = .normal
129+
let buffer = toolCallBuffer
130+
toolCallBuffer = ""
131+
return leading + buffer
132+
}
133+
134+
return leading.isEmpty ? nil : leading
135+
}
136+
137+
// No brace seen — pass through as regular text
138+
return chunk
106139

107-
if let toolCall = parser.parse(content: toolCallBuffer, tools: tools) {
108-
toolCalls.append(toolCall)
109-
toolCallBuffer = ""
140+
case .potentialToolCall, .collectingToolCall:
141+
toolCallBuffer += chunk
142+
143+
if let toolCall = parser.parse(content: toolCallBuffer, tools: tools) {
144+
toolCalls.append(toolCall)
145+
toolCallBuffer = ""
146+
state = .normal
147+
return nil
148+
}
149+
150+
// If braces are balanced but parse failed, this isn't a tool call — flush
151+
if jsonBracesBalanced(toolCallBuffer) {
152+
state = .normal
153+
let buffer = toolCallBuffer
154+
toolCallBuffer = ""
155+
return buffer
156+
}
157+
158+
// Still collecting
110159
return nil
111160
}
161+
}
112162

113-
// Return chunk as-is; caller handles incomplete inline tool calls
114-
return chunk
163+
/// Check whether open/close braces are balanced in the string.
164+
private func jsonBracesBalanced(_ text: String) -> Bool {
165+
var depth = 0
166+
for ch in text {
167+
if ch == "{" { depth += 1 } else if ch == "}" { depth -= 1 }
168+
}
169+
return depth == 0
115170
}
116171

117172
/// Process chunk for tagged formats.

0 commit comments

Comments
 (0)