Skip to content

Commit 1500acc

Browse files
committed
Pass instructions as system prompt for MLX
1 parent 4be4168 commit 1500acc

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,20 @@ import Foundation
8383
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
8484
let generateParameters = toGenerateParameters(options)
8585

86-
// Start with user prompt
86+
// Build chat history starting with system message if instructions are present
87+
var chat: [MLXLMCommon.Chat.Message] = []
88+
89+
// Add system message if instructions are present
90+
if let instructionSegments = extractInstructionSegments(from: session) {
91+
let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments)
92+
chat.append(systemMessage)
93+
}
94+
95+
// Add user prompt
8796
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
8897
let userMessage = convertSegmentsToMLXMessage(userSegments)
89-
var chat: [MLXLMCommon.Chat.Message] = [userMessage]
98+
chat.append(userMessage)
99+
90100
var allTextChunks: [String] = []
91101
var allEntries: [Transcript.Entry] = []
92102

@@ -211,6 +221,20 @@ import Foundation
211221
return [.text(.init(content: fallbackText))]
212222
}
213223

224+
private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? {
225+
// Prefer the first Transcript.Instructions entry if present
226+
for entry in session.transcript {
227+
if case .instructions(let i) = entry {
228+
return i.segments
229+
}
230+
}
231+
// Fallback to session.instructions
232+
if let instructions = session.instructions?.description, !instructions.isEmpty {
233+
return [.text(.init(content: instructions))]
234+
}
235+
return nil
236+
}
237+
214238
private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
215239
var textParts: [String] = []
216240
var images: [MLXLMCommon.UserInput.Image] = []
@@ -248,6 +272,43 @@ import Foundation
248272
return MLXLMCommon.Chat.Message(role: .user, content: content, images: images)
249273
}
250274

275+
private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
276+
var textParts: [String] = []
277+
var images: [MLXLMCommon.UserInput.Image] = []
278+
279+
for segment in segments {
280+
switch segment {
281+
case .text(let text):
282+
textParts.append(text.content)
283+
case .structure(let structured):
284+
textParts.append(structured.content.jsonString)
285+
case .image(let imageSegment):
286+
switch imageSegment.source {
287+
case .url(let url):
288+
images.append(.url(url))
289+
case .data(let data, _):
290+
#if canImport(UIKit)
291+
if let uiImage = UIKit.UIImage(data: data),
292+
let ciImage = CIImage(image: uiImage)
293+
{
294+
images.append(.ciImage(ciImage))
295+
}
296+
#elseif canImport(AppKit)
297+
if let nsImage = AppKit.NSImage(data: data),
298+
let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
299+
{
300+
let ciImage = CIImage(cgImage: cgImage)
301+
images.append(.ciImage(ciImage))
302+
}
303+
#endif
304+
}
305+
}
306+
}
307+
308+
let content = textParts.joined(separator: "\n")
309+
return MLXLMCommon.Chat.Message(role: .system, content: content, images: images)
310+
}
311+
251312
// MARK: - Tool Conversion
252313

253314
private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec {

0 commit comments

Comments
 (0)