Skip to content

Commit 9700936

Browse files
noorbhatia and mattt authored
Refactor message history construction from Transcription in MLXLanguageModel (#66)
* Refactor chat history construction to utilize full transcript conversion in MLXLanguageModel
* Rename convertSegmentsToMLXMessage to convertSegmentsToMLXUserMessage
* DRY up creation of MLX messages from segments
* Rename segmentToText to extractText and refactor makeMLXChatMessage to use it

---------

Co-authored-by: Mattt Zmuda <[email protected]>
1 parent 97db31f commit 9700936

File tree

1 file changed

+65
-71
lines changed

1 file changed

+65
-71
lines changed

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 65 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -83,19 +83,8 @@ import Foundation
8383
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
8484
let generateParameters = toGenerateParameters(options)
8585

86-
// Build chat history starting with system message if instructions are present
87-
var chat: [MLXLMCommon.Chat.Message] = []
88-
89-
// Add system message if instructions are present
90-
if let instructionSegments = extractInstructionSegments(from: session) {
91-
let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments)
92-
chat.append(systemMessage)
93-
}
94-
95-
// Add user prompt
96-
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
97-
let userMessage = convertSegmentsToMLXMessage(userSegments)
98-
chat.append(userMessage)
86+
// Build chat history from full transcript
87+
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
9988

10089
var allTextChunks: [String] = []
10190
var allEntries: [Transcript.Entry] = []
@@ -208,80 +197,80 @@ import Foundation
208197
)
209198
}
210199

211-
// MARK: - Segment Extraction
200+
// MARK: - Transcript Conversion
212201

213-
private func extractPromptSegments(from session: LanguageModelSession, fallbackText: String) -> [Transcript.Segment]
214-
{
215-
// Prefer the most recent Transcript.Prompt entry if present
216-
for entry in session.transcript.reversed() {
217-
if case .prompt(let p) = entry {
218-
return p.segments
219-
}
202+
private func convertTranscriptToMLXChat(
203+
session: LanguageModelSession,
204+
fallbackPrompt: String
205+
) -> [MLXLMCommon.Chat.Message] {
206+
var chat: [MLXLMCommon.Chat.Message] = []
207+
208+
// Check if instructions are already in transcript
209+
let hasInstructionsInTranscript = session.transcript.contains {
210+
if case .instructions = $0 { return true }
211+
return false
220212
}
221-
return [.text(.init(content: fallbackText))]
222-
}
223213

224-
private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? {
225-
// Prefer the first Transcript.Instructions entry if present
214+
// Add instructions from session if present and not in transcript
215+
if !hasInstructionsInTranscript,
216+
let instructions = session.instructions?.description,
217+
!instructions.isEmpty
218+
{
219+
chat.append(.init(role: .system, content: instructions))
220+
}
221+
222+
// Convert each transcript entry
226223
for entry in session.transcript {
227-
if case .instructions(let i) = entry {
228-
return i.segments
224+
switch entry {
225+
case .instructions(let instr):
226+
chat.append(makeMLXChatMessage(from: instr.segments, role: .system))
227+
228+
case .prompt(let prompt):
229+
chat.append(makeMLXChatMessage(from: prompt.segments, role: .user))
230+
231+
case .response(let response):
232+
let content = response.segments.map { extractText(from: $0) }.joined(separator: "\n")
233+
chat.append(.assistant(content))
234+
235+
case .toolCalls:
236+
// Tool calls are handled inline during generation loop
237+
break
238+
239+
case .toolOutput(let toolOutput):
240+
let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n")
241+
chat.append(.tool(content))
229242
}
230243
}
231-
// Fallback to session.instructions
232-
if let instructions = session.instructions?.description, !instructions.isEmpty {
233-
return [.text(.init(content: instructions))]
244+
245+
// If no user message in transcript, add fallback prompt
246+
let hasUserMessage = chat.contains { $0.role == .user }
247+
if !hasUserMessage {
248+
chat.append(.init(role: .user, content: fallbackPrompt))
234249
}
235-
return nil
236-
}
237250

238-
private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
239-
var textParts: [String] = []
240-
var images: [MLXLMCommon.UserInput.Image] = []
251+
return chat
252+
}
241253

242-
for segment in segments {
243-
switch segment {
244-
case .text(let text):
245-
textParts.append(text.content)
246-
case .structure(let structured):
247-
textParts.append(structured.content.jsonString)
248-
case .image(let imageSegment):
249-
switch imageSegment.source {
250-
case .url(let url):
251-
images.append(.url(url))
252-
case .data(let data, _):
253-
#if canImport(UIKit)
254-
if let uiImage = UIKit.UIImage(data: data),
255-
let ciImage = CIImage(image: uiImage)
256-
{
257-
images.append(.ciImage(ciImage))
258-
}
259-
#elseif canImport(AppKit)
260-
if let nsImage = AppKit.NSImage(data: data),
261-
let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
262-
{
263-
let ciImage = CIImage(cgImage: cgImage)
264-
images.append(.ciImage(ciImage))
265-
}
266-
#endif
267-
}
268-
}
254+
private func extractText(from segment: Transcript.Segment) -> String {
255+
switch segment {
256+
case .text(let text):
257+
return text.content
258+
case .structure(let structured):
259+
return structured.content.jsonString
260+
case .image:
261+
return ""
269262
}
270-
271-
let content = textParts.joined(separator: "\n")
272-
return MLXLMCommon.Chat.Message(role: .user, content: content, images: images)
273263
}
274264

275-
private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
265+
private func makeMLXChatMessage(
266+
from segments: [Transcript.Segment],
267+
role: MLXLMCommon.Chat.Message.Role
268+
) -> MLXLMCommon.Chat.Message {
276269
var textParts: [String] = []
277270
var images: [MLXLMCommon.UserInput.Image] = []
278271

279272
for segment in segments {
280273
switch segment {
281-
case .text(let text):
282-
textParts.append(text.content)
283-
case .structure(let structured):
284-
textParts.append(structured.content.jsonString)
285274
case .image(let imageSegment):
286275
switch imageSegment.source {
287276
case .url(let url):
@@ -302,11 +291,16 @@ import Foundation
302291
}
303292
#endif
304293
}
294+
default:
295+
let text = extractText(from: segment)
296+
if !text.isEmpty {
297+
textParts.append(text)
298+
}
305299
}
306300
}
307301

308302
let content = textParts.joined(separator: "\n")
309-
return MLXLMCommon.Chat.Message(role: .system, content: content, images: images)
303+
return MLXLMCommon.Chat.Message(role: role, content: content, images: images)
310304
}
311305

312306
// MARK: - Tool Conversion

0 commit comments

Comments
 (0)