Skip to content

Commit c1e8e36

Browse files
committed
Add parser for GPT-OSS Harmony tool call format
1 parent b362c8a commit c1e8e36

File tree

6 files changed

+728
-15
lines changed

6 files changed

+728
-15
lines changed

Libraries/MLXLLM/Models/GPTOSS.swift

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import Foundation
1111
import MLX
1212
import MLXLMCommon
1313
import MLXNN
14+
import Tokenizers
1415

1516
// MARK: - Configuration
1617

@@ -49,7 +50,7 @@ public struct GPTOSSConfiguration: Codable, Sendable {
4950
case layerTypes = "layer_types"
5051
}
5152

52-
public init(from decoder: Decoder) throws {
53+
public init(from decoder: Swift.Decoder) throws {
5354
let container = try decoder.container(keyedBy: CodingKeys.self)
5455
self.modelType = try container.decode(String.self, forKey: .modelType)
5556
self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
@@ -99,6 +100,123 @@ private let compiledSwiglu: @Sendable (MLXArray, MLXArray) -> MLXArray = compile
99100
swiglu(xLinear, xGlu)
100101
}
101102

103+
/// GPT-OSS specific history serializer.
///
/// Assistant turns whose content consists solely of newline-separated JSON objects of the
/// form `{"name": ..., "arguments": ...}` are emitted as a `tool_calls` message. Every other
/// turn is emitted as a plain `role` / `content` message.
struct GPTOSSMessageGenerator: MessageGenerator {
    /// Serializes a chat history by applying `generate(message:)` to every turn, in order.
    func generate(messages: [Chat.Message]) -> [Message] {
        messages.map { generate(message: $0) }
    }

    /// Serializes a single turn.
    ///
    /// Both entry points share this implementation so that a caller serializing one message
    /// at a time gets the same output as a caller serializing the whole history.
    ///
    /// - Returns: A `tool_calls` message when `message` is an assistant turn whose content
    ///   parses as stored tool calls, otherwise the default `role` / `content` message.
    func generate(message: Chat.Message) -> Message {
        guard
            message.role == .assistant,
            let toolCalls = parseStoredToolCalls(from: message.content)
        else {
            return defaultMessage(for: message)
        }

        return [
            "role": "assistant",
            "tool_calls": toolCalls,
        ]
    }

    /// Builds the plain `role` / `content` representation of a turn.
    private func defaultMessage(for message: Chat.Message) -> Message {
        [
            "role": message.role.rawValue,
            "content": message.content,
        ]
    }

    /// Interprets `content` as one JSON tool call per non-empty line.
    ///
    /// Parsing is all-or-nothing: if any non-empty line is not a JSON object carrying a string
    /// `name` and a convertible `arguments` value, the whole content is rejected.
    ///
    /// - Returns: One `{"type": "function", "function": {...}}` entry per line, or `nil` when
    ///   there are no non-empty lines or any line fails to parse.
    private func parseStoredToolCalls(from content: String) -> [[String: any Sendable]]? {
        let lines =
            content
            .split(whereSeparator: \.isNewline)
            .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
            .filter { !$0.isEmpty }

        var toolCalls: [[String: any Sendable]] = []
        toolCalls.reserveCapacity(lines.count)

        for line in lines {
            guard
                let data = line.data(using: .utf8),
                let raw = try? JSONSerialization.jsonObject(with: data),
                let dict = raw as? [String: Any],
                let rawName = dict["name"] as? String,
                let rawArguments = dict["arguments"],
                let arguments = normalizedSendableJSON(from: rawArguments)
            else {
                return nil
            }

            let toolName = normalizeToolName(rawName)
            guard !toolName.isEmpty else { return nil }

            let function: [String: any Sendable] = [
                "name": toolName,
                "arguments": arguments,
            ]
            toolCalls.append([
                "type": "function",
                "function": function,
            ])
        }

        return toolCalls.isEmpty ? nil : toolCalls
    }

    /// Trims surrounding whitespace and removes a leading `functions.` namespace, if present.
    private func normalizeToolName(_ name: String) -> String {
        let trimmed = name.trimmingCharacters(in: .whitespacesAndNewlines)
        guard trimmed.hasPrefix("functions.") else { return trimmed }
        return String(trimmed.dropFirst("functions.".count))
    }

    /// Recursively converts a `JSONSerialization` value into `Sendable` Swift values.
    ///
    /// - Returns: A dictionary, array, `String`, `Bool`, `Int` or `Double`, or `nil` when the
    ///   value (or anything nested inside it) is of another kind, such as `NSNull`.
    private func normalizedSendableJSON(from value: Any) -> (any Sendable)? {
        if let dict = value as? [String: Any] {
            var output: [String: any Sendable] = [:]
            output.reserveCapacity(dict.count)
            for (key, element) in dict {
                guard let normalized = normalizedSendableJSON(from: element) else { return nil }
                output[key] = normalized
            }
            return output
        }

        if let array = value as? [Any] {
            var output: [any Sendable] = []
            output.reserveCapacity(array.count)
            for element in array {
                guard let normalized = normalizedSendableJSON(from: element) else { return nil }
                output.append(normalized)
            }
            return output
        }

        if let string = value as? String { return string }

        // Booleans must be recognised by their CoreFoundation type *before* any `as? Bool`
        // or `as? Int` cast: an NSNumber holding 0 or 1 bridges to Bool, and a CFBoolean
        // bridges to Int, so cast order alone cannot tell `1` from `true`.
        if let number = value as? NSNumber, CFGetTypeID(number) == CFBooleanGetTypeID() {
            return number.boolValue
        }

        if let int = value as? Int { return int }
        if let double = value as? Double { return double }

        if let number = value as? NSNumber {
            // Reached only for numbers that bridge to neither Int nor Double exactly.
            // `Int(exactly:)` keeps out-of-range integral values as Double instead of
            // producing a meaningless integer.
            let doubleValue = number.doubleValue
            if let int = Int(exactly: doubleValue) {
                return int
            }
            return doubleValue
        }

        return nil
    }
}
219+
102220
class SwiGLUSwitchGLU: Module {
103221
@ModuleInfo(key: "gate_proj") var gateProj: SwitchLinear
104222
@ModuleInfo(key: "up_proj") var upProj: SwitchLinear
@@ -526,6 +644,10 @@ public class GPTOSSModel: Module, LLMModel, KVCacheDimensionProvider {
526644
return finalWeights
527645
}
528646

647+
/// Supplies the message generator used to serialize chat history for this model.
///
/// - Parameter tokenizer: Not consulted by this implementation.
/// - Returns: A new `GPTOSSMessageGenerator`.
public func messageGenerator(tokenizer: any Tokenizer) -> any MessageGenerator {
    let generator = GPTOSSMessageGenerator()
    return generator
}
650+
529651
public func newCache(parameters: GenerateParameters?) -> [any KVCache] {
530652
var caches: [KVCache] = []
531653

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
5+
/// Parser for GPT-OSS Harmony tool call format.
///
/// Example:
///
///     <|channel|>commentary to=get_weather
///     <|message|>{"location": "Tokyo"}
///     <|call|>
///
/// A call is recognised when the header preceding `<|message|>` names the `commentary`
/// channel and carries a `to=` recipient, either after `<|channel|>` or in the role section
/// before it. The payload after `<|message|>` supplies the arguments.
public struct GPTOSSToolCallParser: ToolCallParser, Sendable {
    public let startTag: String? = Tag.channelTool
    public let startTags: [String] = [Tag.channelTool, Tag.assistantTool, Tag.assistantChannelTool]
    public let endTag: String? = Tag.call

    /// Harmony control tokens and the header prefixes that open a tool call.
    private enum Tag {
        static let channel = "<|channel|>"
        static let message = "<|message|>"
        static let end = "<|end|>"
        static let call = "<|call|>"
        static let `return` = "<|return|>"
        static let constrain = "<|constrain|>"
        static let assistantTool = "<|start|>assistant to="
        static let assistantChannelTool = "<|start|>assistant<|channel|>commentary to="
        static let channelTool = "<|channel|>commentary to="
    }

    public init() {}

    /// Parses the first tool call found in `content`.
    ///
    /// - Parameters:
    ///   - content: Text containing a header, `<|message|>` and a payload.
    ///   - tools: Unused by this parser.
    /// - Returns: The parsed call, or `nil` when `content` holds no valid call.
    public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
        extractSingle(from: content)
    }

    /// Parses every tool call found in `toolCallBuffer`.
    ///
    /// - Parameters:
    ///   - toolCallBuffer: Buffered text that may hold several messages.
    ///   - tools: Unused by this parser.
    /// - Returns: The valid calls in order of appearance; messages that are not valid calls
    ///   are skipped.
    public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
        extractAll(from: toolCallBuffer)
    }

    /// Walks `text` one `<|message|>` at a time, collecting each message that parses as a call.
    private func extractAll(from text: String) -> [ToolCall] {
        var calls: [ToolCall] = []
        var cursor = text.startIndex

        while cursor < text.endIndex {
            guard
                let messageRange = text.range(of: Tag.message, range: cursor ..< text.endIndex)
            else {
                break
            }

            let header = String(text[cursor ..< messageRange.lowerBound])
            let boundary = firstBoundary(in: text, from: messageRange.upperBound)
            let payloadEnd = boundary?.lowerBound ?? text.endIndex
            let payload = String(text[messageRange.upperBound ..< payloadEnd])

            if let function = parseFunction(header: header, payload: payload) {
                calls.append(ToolCall(function: function))
            }

            guard let boundary else {
                cursor = text.endIndex
                continue
            }

            // If the boundary is a new-call channel opener, keep the tag so the next
            // iteration's header includes it (parseFunction requires <|channel|>).
            // Terminator tags are consumed instead.
            cursor = text[boundary] == Tag.channel ? boundary.lowerBound : boundary.upperBound
        }

        return calls
    }

    /// Parses the first message in `chunk`, cutting its payload at the earliest control tag.
    private func extractSingle(from chunk: String) -> ToolCall? {
        guard let messageRange = chunk.range(of: Tag.message) else { return nil }

        let header = String(chunk[..<messageRange.lowerBound])

        // Truncating at each tag in turn leaves the prefix up to whichever tag occurs first.
        var payload = chunk[messageRange.upperBound...]
        for tag in [Tag.end, Tag.call, Tag.return, Tag.channel] {
            if let tagRange = payload.range(of: tag) {
                payload = payload[..<tagRange.lowerBound]
            }
        }

        return parseFunction(header: header, payload: String(payload)).map {
            ToolCall(function: $0)
        }
    }

    /// Builds a function call from a message header and its payload.
    ///
    /// - Parameters:
    ///   - header: Everything preceding `<|message|>`.
    ///   - payload: Everything following `<|message|>`, already cut at its boundary.
    /// - Returns: `nil` when the header has no `<|channel|>`, the channel is not
    ///   `commentary`, no recipient is named, or the payload is not a JSON object.
    private func parseFunction(header: String, payload: String) -> ToolCall.Function? {
        let trimmedHeader = header.trimmingCharacters(in: .whitespacesAndNewlines)

        // Search backwards so the channel tag closest to the message wins.
        guard let channelRange = trimmedHeader.range(of: Tag.channel, options: .backwards) else {
            return nil
        }

        let channelHeader = String(trimmedHeader[channelRange.upperBound...])
            .trimmingCharacters(in: .whitespacesAndNewlines)
        guard channelHeader.hasPrefix("commentary") else { return nil }

        // The recipient may follow the channel name or sit in the role section before it.
        let roleHeader = String(trimmedHeader[..<channelRange.lowerBound])
        let rawRecipient = recipient(in: channelHeader) ?? recipient(in: roleHeader) ?? ""

        let functionName = canonicalName(from: rawRecipient)
        guard !functionName.isEmpty else { return nil }

        let argumentsJSON = normalizedArgumentsJSON(from: payload)
        guard let arguments = tryParseJSON(argumentsJSON) as? [String: any Sendable] else {
            return nil
        }

        return ToolCall.Function(name: functionName, arguments: arguments)
    }

    /// Extracts the token following `to=` in `headerSection`.
    ///
    /// The token ends at the first whitespace, `<`, `>` or `<|constrain|>` tag.
    ///
    /// - Returns: The recipient, or `nil` when `to=` is absent or followed by nothing.
    private func recipient(in headerSection: String) -> String? {
        guard let toRange = headerSection.range(of: "to=") else { return nil }

        var suffix = headerSection[toRange.upperBound...]
        if let constrainRange = suffix.range(of: Tag.constrain) {
            suffix = suffix[..<constrainRange.lowerBound]
        }

        let token =
            suffix
            .trimmingCharacters(in: .whitespacesAndNewlines)
            .split(whereSeparator: { $0.isWhitespace || $0 == "<" || $0 == ">" })
            .first
        guard let token, !token.isEmpty else { return nil }
        return String(token)
    }

    /// Trims surrounding whitespace and removes a leading `functions.` namespace, if present.
    private func canonicalName(from rawName: String) -> String {
        let trimmed = rawName.trimmingCharacters(in: .whitespacesAndNewlines)
        guard trimmed.hasPrefix("functions.") else { return trimmed }
        return String(trimmed.dropFirst("functions.".count))
    }

    /// Reduces a raw payload to the JSON text that should be parsed as the arguments.
    ///
    /// An empty payload becomes `{}`. Otherwise a surrounding code fence is removed, then a
    /// JSON-string-encoded object or an object nested under a wrapper key is unwrapped when
    /// present. If neither applies the unfenced text is returned unchanged.
    private func normalizedArgumentsJSON(from payload: String) -> String {
        let trimmed = payload.trimmingCharacters(in: .whitespacesAndNewlines)
        guard !trimmed.isEmpty else { return "{}" }

        let unfenced = stripCodeFenceIfNeeded(from: trimmed)
        return unwrappedJSONString(from: unfenced)
            ?? extractedNestedJSONObject(from: unfenced)
            ?? unfenced
    }

    /// Removes a surrounding Markdown code fence from `text`, if it starts with one.
    ///
    /// - Returns: `text` unchanged when it does not start with a fence; otherwise the fenced
    ///   content, trimmed of surrounding whitespace.
    private func stripCodeFenceIfNeeded(from text: String) -> String {
        let fence = "```"
        guard text.hasPrefix(fence) else { return text }

        var body: Substring
        // Match on `isNewline` rather than "\n": a CRLF pair is a single Character that
        // does not compare equal to "\n", so an equality search would miss it.
        if let firstNewline = text.firstIndex(where: \.isNewline) {
            // Multi-line fence: the whole first line is the fence plus its info string.
            body = text[text.index(after: firstNewline)...]
        } else {
            // Single-line fence: drop the opening fence and any language tag explicitly,
            // so the backwards search below cannot mistake it for the closing fence.
            body = text.dropFirst(fence.count).drop(while: { $0.isLetter || $0.isNumber })
        }

        if let closingFence = body.range(of: fence, options: .backwards) {
            body = body[..<closingFence.lowerBound]
        }
        return body.trimmingCharacters(in: .whitespacesAndNewlines)
    }

    /// Decodes `text` as a JSON string literal whose value is itself object-shaped.
    ///
    /// - Returns: The decoded, trimmed string when it starts with `{` and ends with `}`;
    ///   otherwise `nil`.
    private func unwrappedJSONString(from text: String) -> String? {
        guard
            let data = text.data(using: .utf8),
            let decoded = try? JSONDecoder().decode(String.self, from: data)
        else {
            return nil
        }

        let candidate = decoded.trimmingCharacters(in: .whitespacesAndNewlines)
        guard candidate.hasPrefix("{"), candidate.hasSuffix("}") else { return nil }
        return candidate
    }

    /// Unwraps arguments nested under a well-known wrapper key.
    ///
    /// Keys are tried in a fixed order. A nested object is re-serialized with sorted keys;
    /// a nested object-shaped string is returned trimmed.
    ///
    /// - Returns: The nested JSON text, or `nil` when `text` is not a JSON object or no
    ///   wrapper key holds an object or object-shaped string.
    private func extractedNestedJSONObject(from text: String) -> String? {
        guard
            let data = text.data(using: .utf8),
            let json = try? JSONSerialization.jsonObject(with: data),
            let dict = json as? [String: Any]
        else {
            return nil
        }

        let wrapperKeys = ["arguments", "args", "input", "parameters", "kwargs"]
        for key in wrapperKeys {
            guard let nested = dict[key] else { continue }

            if let nestedDict = nested as? [String: Any],
                JSONSerialization.isValidJSONObject(nestedDict),
                let nestedData = try? JSONSerialization.data(
                    withJSONObject: nestedDict, options: [.sortedKeys]),
                let nestedJSON = String(data: nestedData, encoding: .utf8)
            {
                return nestedJSON
            }

            if let nestedString = nested as? String {
                let candidate = nestedString.trimmingCharacters(in: .whitespacesAndNewlines)
                if candidate.hasPrefix("{"), candidate.hasSuffix("}") {
                    return candidate
                }
            }
        }
        return nil
    }

    /// Finds the earliest tag at or after `index` that ends the current payload.
    ///
    /// `<|end|>`, `<|call|>` and `<|return|>` always count. The first `<|channel|>` counts
    /// only when a `<|message|>` follows it somewhere later in `text`.
    ///
    /// - Returns: The range of the earliest such tag, or `nil` when `index` is at the end of
    ///   `text` or no tag is found.
    private func firstBoundary(in text: String, from index: String.Index) -> Range<String.Index>? {
        guard index < text.endIndex else { return nil }
        let searchRange = index ..< text.endIndex

        var candidates = [Tag.end, Tag.call, Tag.return].compactMap {
            text.range(of: $0, range: searchRange)
        }

        if let channelRange = text.range(of: Tag.channel, range: searchRange),
            text.range(of: Tag.message, range: channelRange.upperBound ..< text.endIndex) != nil
        {
            candidates.append(channelRange)
        }

        return candidates.min { $0.lowerBound < $1.lowerBound }
    }
}

0 commit comments

Comments
 (0)