Skip to content

Commit 1a04460

Browse files
committed
Add parser for GPT-OSS Harmony tool call format
1 parent b362c8a commit 1a04460

File tree

4 files changed

+478
-10
lines changed

4 files changed

+478
-10
lines changed
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
5+
/// Parser for GPT-OSS Harmony tool call format.
6+
/// Example:
7+
/// <|channel|>
8+
/// commentary to=get_weather
9+
/// <|message|>
10+
/// {"location": "Tokyo"}
11+
/// <|call|>
12+
public struct GPTOSSToolCallParser: ToolCallParser, Sendable {
13+
public let startTag: String? = Tag.channelTool
14+
public let startTags: [String] = [Tag.channelTool, Tag.assistantTool]
15+
public let endTag: String? = "<|call|>"
16+
17+
private enum Tag {
18+
static let channel = "<|channel|>"
19+
static let message = "<|message|>"
20+
static let end = "<|end|>"
21+
static let call = "<|call|>"
22+
static let `return` = "<|return|>"
23+
static let constrain = "<|constrain|>"
24+
static let assistantTool = "<|start|>assistant to="
25+
static let channelTool = "<|channel|>commentary to="
26+
}
27+
28+
public init() {}
29+
30+
public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
31+
return extractSingle(from: content)
32+
}
33+
34+
public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
35+
return extractAll(from: toolCallBuffer)
36+
}
37+
38+
private func extractAll(from text: String) -> [ToolCall] {
39+
var calls: [ToolCall] = []
40+
var cursor = text.startIndex
41+
42+
while cursor < text.endIndex {
43+
guard
44+
let messageRange = text.range(
45+
of: Tag.message, range: cursor ..< text.endIndex)
46+
else {
47+
break
48+
}
49+
50+
let rawHeader = String(text[cursor ..< messageRange.lowerBound])
51+
let boundary = firstBoundary(in: text, from: messageRange.upperBound)
52+
let payloadEnd = boundary?.lowerBound ?? text.endIndex
53+
let payload = String(text[messageRange.upperBound ..< payloadEnd])
54+
55+
if let function = parseFunction(header: rawHeader, payload: payload) {
56+
calls.append(ToolCall(function: function))
57+
}
58+
59+
if let boundary {
60+
cursor = boundary.upperBound
61+
} else {
62+
cursor = text.endIndex
63+
}
64+
}
65+
return calls
66+
}
67+
68+
private func extractSingle(from chunk: String) -> ToolCall? {
69+
guard let messageRange = chunk.range(of: Tag.message) else { return nil }
70+
let rawChannelHeader = String(chunk[..<messageRange.lowerBound])
71+
let payload = String(chunk[messageRange.upperBound...])
72+
var cleanPayload = payload
73+
for tag in [Tag.end, Tag.call, Tag.return, Tag.channel] {
74+
if let tagRange = cleanPayload.range(of: tag) {
75+
cleanPayload = String(cleanPayload[..<tagRange.lowerBound])
76+
}
77+
}
78+
79+
guard let function = parseFunction(header: rawChannelHeader, payload: cleanPayload) else {
80+
return nil
81+
}
82+
return ToolCall(function: function)
83+
}
84+
85+
private func parseFunction(header: String, payload: String) -> ToolCall.Function? {
86+
let trimmedHeader = header.trimmingCharacters(in: .whitespacesAndNewlines)
87+
88+
guard let channelRange = trimmedHeader.range(of: Tag.channel, options: .backwards) else {
89+
return nil
90+
}
91+
92+
let channelHeader = String(trimmedHeader[channelRange.upperBound...]).trimmingCharacters(
93+
in: .whitespacesAndNewlines)
94+
guard channelHeader.hasPrefix("commentary") else { return nil }
95+
96+
let roleHeader = String(trimmedHeader[..<channelRange.lowerBound])
97+
let recipient = recipient(in: channelHeader) ?? recipient(in: roleHeader) ?? ""
98+
99+
let normalizedName = canonicalName(from: recipient)
100+
guard !normalizedName.isEmpty else { return nil }
101+
102+
let normalizedArguments = normalizedArgumentsJSON(from: payload)
103+
guard let argsDict = tryParseJSON(normalizedArguments) as? [String: any Sendable] else {
104+
return nil
105+
}
106+
107+
return ToolCall.Function(name: normalizedName, arguments: argsDict)
108+
}
109+
110+
private func recipient(in headerSection: String) -> String? {
111+
guard let toRange = headerSection.range(of: "to=") else { return nil }
112+
var suffix = String(headerSection[toRange.upperBound...])
113+
if let constrainRange = suffix.range(of: Tag.constrain) {
114+
suffix = String(suffix[..<constrainRange.lowerBound])
115+
}
116+
let recipient =
117+
suffix
118+
.trimmingCharacters(in: .whitespacesAndNewlines)
119+
.split(whereSeparator: { $0.isWhitespace || $0 == "<" || $0 == ">" })
120+
.first.map(String.init) ?? ""
121+
return recipient.isEmpty ? nil : recipient
122+
}
123+
124+
private func canonicalName(from rawName: String) -> String {
125+
let trimmed = rawName.trimmingCharacters(in: .whitespacesAndNewlines)
126+
guard !trimmed.isEmpty else { return "" }
127+
if trimmed.hasPrefix("functions.") {
128+
return String(trimmed.dropFirst("functions.".count))
129+
}
130+
return trimmed
131+
}
132+
133+
private func normalizedArgumentsJSON(from payload: String) -> String {
134+
let trimmed = payload.trimmingCharacters(in: .whitespacesAndNewlines)
135+
guard !trimmed.isEmpty else { return "{}" }
136+
137+
let unfenced = stripCodeFenceIfNeeded(from: trimmed)
138+
139+
if let unwrappedJSONString = unwrappedJSONString(from: unfenced) {
140+
return unwrappedJSONString
141+
}
142+
143+
if let extractedObject = extractedNestedJSONObject(from: unfenced) {
144+
return extractedObject
145+
}
146+
147+
return unfenced
148+
}
149+
150+
private func stripCodeFenceIfNeeded(from text: String) -> String {
151+
guard text.hasPrefix("```") else { return text }
152+
var normalized = text
153+
if let firstNewline = normalized.firstIndex(of: "\n") {
154+
normalized = String(normalized[normalized.index(after: firstNewline)...])
155+
}
156+
if let closingFence = normalized.range(of: "```", options: .backwards) {
157+
normalized = String(normalized[..<closingFence.lowerBound])
158+
}
159+
return normalized.trimmingCharacters(in: .whitespacesAndNewlines)
160+
}
161+
162+
private func unwrappedJSONString(from text: String) -> String? {
163+
guard let data = text.data(using: .utf8),
164+
let stringValue = try? JSONDecoder().decode(String.self, from: data)
165+
else {
166+
return nil
167+
}
168+
let normalized = stringValue.trimmingCharacters(in: .whitespacesAndNewlines)
169+
guard normalized.hasPrefix("{"), normalized.hasSuffix("}") else { return nil }
170+
return normalized
171+
}
172+
173+
private func extractedNestedJSONObject(from text: String) -> String? {
174+
guard let data = text.data(using: .utf8),
175+
let json = try? JSONSerialization.jsonObject(with: data),
176+
let dict = json as? [String: Any]
177+
else {
178+
return nil
179+
}
180+
181+
let wrapperKeys = ["arguments", "args", "input", "parameters", "kwargs"]
182+
for key in wrapperKeys {
183+
guard let nested = dict[key] else { continue }
184+
if let nestedDict = nested as? [String: Any],
185+
JSONSerialization.isValidJSONObject(nestedDict),
186+
let nestedData = try? JSONSerialization.data(
187+
withJSONObject: nestedDict, options: [.sortedKeys]),
188+
let nestedJSON = String(data: nestedData, encoding: .utf8)
189+
{
190+
return nestedJSON
191+
}
192+
if let nestedString = nested as? String {
193+
let normalized = nestedString.trimmingCharacters(in: .whitespacesAndNewlines)
194+
if normalized.hasPrefix("{"), normalized.hasSuffix("}") {
195+
return normalized
196+
}
197+
}
198+
}
199+
return nil
200+
}
201+
202+
private func firstBoundary(in text: String, from index: String.Index) -> Range<String.Index>? {
203+
guard index < text.endIndex else { return nil }
204+
let searchRange = index ..< text.endIndex
205+
206+
let tags = [Tag.end, Tag.call, Tag.return]
207+
var minRange: Range<String.Index>? = nil
208+
209+
for tag in tags {
210+
if let range = text.range(of: tag, range: searchRange) {
211+
if minRange == nil || range.lowerBound < minRange!.lowerBound {
212+
minRange = range
213+
}
214+
}
215+
}
216+
217+
if let channelRange = text.range(of: Tag.channel, range: searchRange) {
218+
let nextSearch = channelRange.upperBound ..< text.endIndex
219+
if text.range(of: Tag.message, range: nextSearch) != nil {
220+
if minRange == nil || channelRange.lowerBound < minRange!.lowerBound {
221+
minRange = channelRange
222+
}
223+
}
224+
}
225+
226+
return minRange
227+
}
228+
}

Libraries/MLXLMCommon/Tool/ToolCallFormat.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ public protocol ToolCallParser: Sendable {
1515
/// Returns `nil` for inline formats that don't use wrapper tags.
1616
var startTag: String? { get }
1717

18+
/// Optional additional start tags for formats with multiple valid headers.
19+
/// Defaults to `[startTag]` when a start tag exists.
20+
var startTags: [String] { get }
21+
1822
/// The end tag that indicates a tool call has ended.
1923
/// Returns `nil` for inline formats that don't use wrapper tags.
2024
var endTag: String? { get }
@@ -35,6 +39,11 @@ public protocol ToolCallParser: Sendable {
3539
}
3640

3741
extension ToolCallParser {
42+
public var startTags: [String] {
43+
guard let startTag else { return [] }
44+
return [startTag]
45+
}
46+
3847
public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
3948
if let startTag {
4049
return
@@ -94,6 +103,10 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
94103
/// Example: `[TOOL_CALLS]get_weather [ARGS]{"location": "Tokyo"}`
95104
case mistral
96105

106+
/// GPT-OSS Harmony tool call format.
107+
/// Example: `<|channel|>commentary to=get_weather<|message|>{"location": "Tokyo"}<|call|>`
108+
case gptOSS = "gpt_oss"
109+
97110
// MARK: - Factory Methods
98111

99112
/// Create the appropriate parser for this format.
@@ -117,6 +130,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
117130
return MiniMaxM2ToolCallParser()
118131
case .mistral:
119132
return MistralToolCallParser()
133+
case .gptOSS:
134+
return GPTOSSToolCallParser()
120135
}
121136
}
122137

@@ -150,6 +165,11 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
150165
return .mistral
151166
}
152167

168+
// GPT-OSS family
169+
if type.hasPrefix("gpt_oss") || type == "gptoss" {
170+
return .gptOSS
171+
}
172+
153173
return nil
154174
}
155175
}

Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ public class ToolCallProcessor {
5757

5858
/// Whether this processor uses inline format (no start tag).
5959
private var isInlineFormat: Bool {
60-
parser.startTag == nil
60+
parser.startTags.isEmpty
6161
}
6262

6363
/// The first character of the start tag for quick detection.
6464
private var startTagFirstChar: Character? {
65-
parser.startTag?.first
65+
parser.startTags.first?.first
6666
}
6767

6868
// MARK: - Public Methods
@@ -116,8 +116,8 @@ public class ToolCallProcessor {
116116

117117
/// Process chunk for tagged formats.
118118
private func processTaggedChunk(_ chunk: String) -> String? {
119-
guard let startTag = parser.startTag,
120-
let startChar = startTagFirstChar
119+
let startTags = parser.startTags
120+
guard !startTags.isEmpty, let startChar = startTagFirstChar
121121
else {
122122
return chunk
123123
}
@@ -134,13 +134,18 @@ public class ToolCallProcessor {
134134
// Change state to potential tool call
135135
state = .potentialToolCall
136136

137-
leadingToken = separateToken(
138-
from: &toolCallBuffer, separator: String(startChar), returnLeading: true)
137+
if let startRange = firstStartTagRange(in: toolCallBuffer, tags: startTags) {
138+
leadingToken = String(toolCallBuffer[..<startRange.lowerBound])
139+
toolCallBuffer = String(toolCallBuffer[startRange.lowerBound...])
140+
} else {
141+
leadingToken = separateToken(
142+
from: &toolCallBuffer, separator: String(startChar), returnLeading: true)
143+
}
139144

140145
fallthrough
141146
case .potentialToolCall:
142-
if partialMatch(buffer: toolCallBuffer, tag: startTag) {
143-
if toolCallBuffer.starts(with: startTag) {
147+
if partialMatch(buffer: toolCallBuffer, tags: startTags) {
148+
if startsWithAnyStartTag(buffer: toolCallBuffer, tags: startTags) {
144149
state = .collectingToolCall
145150
fallthrough
146151
} else {
@@ -172,13 +177,26 @@ public class ToolCallProcessor {
172177
toolCallBuffer = ""
173178

174179
// If the token contains the start character, there may be more tool calls to come
180+
let trailingOutput: String?
175181
if let trailingToken, let startChar = startTagFirstChar,
176182
trailingToken.contains(startChar)
177183
{
178-
return processChunk(trailingToken)
184+
trailingOutput = processChunk(trailingToken)
179185
} else {
180186
// Otherwise, return the collected token, or nil if it's empty
181-
return trailingToken?.isEmpty ?? true ? nil : trailingToken
187+
trailingOutput = trailingToken?.isEmpty ?? true ? nil : trailingToken
188+
}
189+
190+
let prefix = leadingToken ?? ""
191+
switch (prefix.isEmpty, trailingOutput) {
192+
case (true, .none):
193+
return nil
194+
case (true, .some(let output)):
195+
return output
196+
case (false, .none):
197+
return prefix
198+
case (false, .some(let output)):
199+
return prefix + output
182200
}
183201
} else {
184202
return nil
@@ -218,4 +236,23 @@ public class ToolCallProcessor {
218236

219237
return true
220238
}
239+
240+
private func partialMatch(buffer: String, tags: [String]) -> Bool {
241+
tags.contains { partialMatch(buffer: buffer, tag: $0) }
242+
}
243+
244+
private func startsWithAnyStartTag(buffer: String, tags: [String]) -> Bool {
245+
tags.contains { buffer.starts(with: $0) }
246+
}
247+
248+
private func firstStartTagRange(in buffer: String, tags: [String]) -> Range<String.Index>? {
249+
var earliest: Range<String.Index>? = nil
250+
for tag in tags {
251+
guard let range = buffer.range(of: tag) else { continue }
252+
if earliest == nil || range.lowerBound < earliest!.lowerBound {
253+
earliest = range
254+
}
255+
}
256+
return earliest
257+
}
221258
}

0 commit comments

Comments
 (0)