Skip to content

Commit 43014bd

Browse files
committed
Pass instructions as system prompt for MLX
1 parent ab0db98 commit 43014bd

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,20 @@ import Foundation
6464
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
6565
let generateParameters = toGenerateParameters(options)
6666

67-
// Start with user prompt
67+
// Build chat history starting with system message if instructions are present
68+
var chat: [MLXLMCommon.Chat.Message] = []
69+
70+
// Add system message if instructions are present
71+
if let instructionSegments = extractInstructionSegments(from: session) {
72+
let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments)
73+
chat.append(systemMessage)
74+
}
75+
76+
// Add user prompt
6877
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
6978
let userMessage = convertSegmentsToMLXMessage(userSegments)
70-
var chat: [MLXLMCommon.Chat.Message] = [userMessage]
79+
chat.append(userMessage)
80+
7181
var allTextChunks: [String] = []
7282
var allEntries: [Transcript.Entry] = []
7383

@@ -192,6 +202,20 @@ import Foundation
192202
return [.text(.init(content: fallbackText))]
193203
}
194204

205+
/// Resolves the segments that should make up the system message for `session`.
///
/// The transcript is consulted first: the segments of the earliest
/// `.instructions` entry are returned as-is. When the transcript holds no
/// such entry, the textual description of `session.instructions` is wrapped
/// in a single text segment, provided it is non-empty.
///
/// - Parameter session: The session whose transcript and instructions are inspected.
/// - Returns: The instruction segments, or `nil` when neither source yields any.
private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? {
    // The first instructions entry wins; later ones are never examined.
    for case .instructions(let transcriptInstructions) in session.transcript {
        return transcriptInstructions.segments
    }

    // No transcript entry: fall back to the session-level instructions text.
    guard let fallbackText = session.instructions?.description, !fallbackText.isEmpty else {
        return nil
    }
    return [.text(.init(content: fallbackText))]
}
218+
195219
private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
196220
var textParts: [String] = []
197221
var images: [MLXLMCommon.UserInput.Image] = []
@@ -229,6 +253,43 @@ import Foundation
229253
return MLXLMCommon.Chat.Message(role: .user, content: content, images: images)
230254
}
231255

256+
/// Builds the MLX chat message with role `.system` from transcript segments.
///
/// Text segments and the JSON string of structured segments become the
/// message content, joined by newlines in their original order. Image
/// segments are collected as attachments: URL sources are forwarded directly,
/// while raw data is attached only if it can be decoded through UIKit or
/// AppKit. Data that fails to decode is skipped, and on platforms where
/// neither framework can be imported, data-backed images are ignored.
///
/// - Parameter segments: The instruction segments to convert.
/// - Returns: A `.system` message carrying the joined text and collected images.
private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
    var lines: [String] = []
    var attachments: [MLXLMCommon.UserInput.Image] = []

    for segment in segments {
        switch segment {
        case .text(let textSegment):
            lines.append(textSegment.content)

        case .structure(let structuredSegment):
            lines.append(structuredSegment.content.jsonString)

        case .image(let imageSegment):
            switch imageSegment.source {
            case .url(let location):
                attachments.append(.url(location))

            case .data(let data, _):
                #if canImport(UIKit)
                    // Undecodable data is skipped; nothing follows the switch in this iteration.
                    guard let decoded = UIKit.UIImage(data: data),
                        let coreImage = CIImage(image: decoded)
                    else { continue }
                    attachments.append(.ciImage(coreImage))
                #elseif canImport(AppKit)
                    // Undecodable data is skipped; nothing follows the switch in this iteration.
                    guard let decoded = AppKit.NSImage(data: data),
                        let bitmap = decoded.cgImage(forProposedRect: nil, context: nil, hints: nil)
                    else { continue }
                    attachments.append(.ciImage(CIImage(cgImage: bitmap)))
                #endif
            }
        }
    }

    return MLXLMCommon.Chat.Message(
        role: .system,
        content: lines.joined(separator: "\n"),
        images: attachments
    )
}
292+
232293
// MARK: - Tool Conversion
233294

234295
private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec {

0 commit comments

Comments
 (0)