Skip to content
Merged
2 changes: 1 addition & 1 deletion Libraries/MLXLMCommon/Load.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public func downloadModel(
case .id(let id, let revision):
// download the model weights
let repo = Hub.Repo(id: id)
let modelFiles = ["*.safetensors", "*.json"]
let modelFiles = ["*.safetensors", "*.json", "*.jinja"]
return try await hub.snapshot(
from: repo,
revision: revision,
Expand Down
47 changes: 28 additions & 19 deletions Libraries/MLXLMCommon/Tool/Parsers/PythonicToolCallParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,37 @@ public struct PythonicToolCallParser: ToolCallParser, Sendable {

text = text.trimmingCharacters(in: .whitespacesAndNewlines)

// Pattern: [function_name(args...)] or function_name(args...)
// Also handle multiple calls: [func1(args), func2(args)]
let pattern = #"\[?(\w+)\((.*?)\)\]?"#

guard
let regex = try? NSRegularExpression(
pattern: pattern, options: [.dotMatchesLineSeparators]),
let funcName: String
let argsString: String

// Required brackets pattern (matches Python reference: r"\[(\w+)\((.*?)\)\]")
// The required \] forces .*? to backtrack past nested ) inside argument values.
let bracketPattern = #"\[(\w+)\((.*?)\)\]"#
if let regex = try? NSRegularExpression(
pattern: bracketPattern, options: [.dotMatchesLineSeparators]),
let match = regex.firstMatch(
in: text, options: [], range: NSRange(text.startIndex..., in: text))
else { return nil }

// Extract function name
guard let nameRange = Range(match.range(at: 1), in: text) else { return nil }
let funcName = String(text[nameRange])

// Extract arguments string
guard let argsRange = Range(match.range(at: 2), in: text) else { return nil }
let argsString = String(text[argsRange])
in: text, options: [], range: NSRange(text.startIndex..., in: text)),
let nameRange = Range(match.range(at: 1), in: text),
let argsRange = Range(match.range(at: 2), in: text)
{
funcName = String(text[nameRange])
argsString = String(text[argsRange])
} else {
// Fallback for without-brackets case: use string indices to find the
// outermost parentheses, avoiding the greedy/non-greedy regex pitfall.
guard let openParen = text.firstIndex(of: "("),
let closeParen = text.lastIndex(of: ")")
else { return nil }

let name = text[text.startIndex ..< openParen]
guard !name.isEmpty, name.allSatisfy({ $0.isLetter || $0.isNumber || $0 == "_" })
else { return nil }

funcName = String(name)
argsString = String(text[text.index(after: openParen) ..< closeParen])
}

// Parse arguments
let arguments = parseArguments(argsString, funcName: funcName, tools: tools)

return ToolCall(function: .init(name: funcName, arguments: arguments))
}

Expand Down
28 changes: 28 additions & 0 deletions Tests/MLXLMTests/ToolTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,34 @@ struct ToolTests {
#expect(toolCall.function.arguments["timezone"] == .string("UTC"))
}

@Test("Test Pythonic Tool Call Parser - Nested Parentheses in Argument Value")
func testPythonicParserNestedParentheses() throws {
let parser = PythonicToolCallParser(
startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
let content =
"<|tool_call_start|>[run_script(code=\"response = requests.get('https://api.example.com/data')\")] <|tool_call_end|>"

let toolCall = try #require(parser.parse(content: content, tools: nil))

#expect(toolCall.function.name == "run_script")
#expect(
toolCall.function.arguments["code"]
== .string("response = requests.get('https://api.example.com/data')"))
}

@Test("Test Pythonic Tool Call Parser - Nested Parentheses Without Brackets")
func testPythonicParserNestedParenthesesNoBrackets() throws {
let parser = PythonicToolCallParser(
startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
let content =
"<|tool_call_start|>run_script(code=\"print('hello')\")<|tool_call_end|>"

let toolCall = try #require(parser.parse(content: content, tools: nil))

#expect(toolCall.function.name == "run_script")
#expect(toolCall.function.arguments["code"] == .string("print('hello')"))
}

@Test("Test Pythonic Tool Call Parser - No Arguments")
func testPythonicParserNoArguments() throws {
let parser = PythonicToolCallParser(
Expand Down
Loading