Skip to content

Commit 877ee49

Browse files
noorbhatia and mattt authored
Pass instructions as system prompt for MLXLanguageModel (#48)
* Pass instructions as system prompt for MLX
* Update test section with information about running MLX tests

---------

Co-authored-by: Mattt Zmuda <[email protected]>
1 parent ad2cb1a commit 877ee49

File tree

2 files changed

+98
-14
lines changed

2 files changed

+98
-14
lines changed

README.md

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -630,28 +630,51 @@ swift test
630630

631631
Tests for different language model backends have varying requirements:
632632

633-
- **CoreML tests**: `swift test --traits CoreML` + `ENABLE_COREML_TESTS=1` + `HF_TOKEN` (downloads model from HuggingFace)
634-
- **MLX tests**: `swift test --traits MLX` + `ENABLE_MLX_TESTS=1` + `HF_TOKEN` (uses pre-defined model)
635-
- **Llama tests**: `swift test --traits Llama` + `LLAMA_MODEL_PATH` (points to local GGUF file)
636-
- **Anthropic tests**: `ANTHROPIC_API_KEY` (no traits needed)
637-
- **OpenAI tests**: `OPENAI_API_KEY` (no traits needed)
638-
- **Ollama tests**: No setup needed (skips in CI)
633+
| Backend | Traits | Environment Variables |
634+
|---------|--------|----------------------|
635+
| CoreML | `CoreML` | `HF_TOKEN` |
636+
| MLX | `MLX` | `HF_TOKEN` |
637+
| Llama | `Llama` | `LLAMA_MODEL_PATH` |
638+
| Anthropic || `ANTHROPIC_API_KEY` |
639+
| OpenAI || `OPENAI_API_KEY` |
640+
| Ollama |||
639641

640-
Example setup for all backends:
642+
Example setup for running multiple tests at once:
641643

642644
```bash
643-
# Environment variables
644-
export ENABLE_COREML_TESTS=1
645-
export ENABLE_MLX_TESTS=1
646645
export HF_TOKEN=your_huggingface_token
647646
export LLAMA_MODEL_PATH=/path/to/model.gguf
648647
export ANTHROPIC_API_KEY=your_anthropic_key
649648
export OPENAI_API_KEY=your_openai_key
650649

651-
# Run all tests with traits enabled
652-
swift test --traits CoreML,MLX,Llama
650+
swift test --traits CoreML,Llama
653651
```
654652

653+
> [!TIP]
654+
> Tests that perform generation are skipped in CI environments (when `CI` is set).
655+
> Override this by setting `ENABLE_COREML_TESTS=1` or `ENABLE_MLX_TESTS=1`.
656+
657+
> [!NOTE]
658+
> MLX tests must be run with `xcodebuild` rather than `swift test`
659+
> due to Metal library loading requirements.
660+
> Since `xcodebuild` doesn't support package traits directly,
661+
> you'll first need to update `Package.swift` to enable the MLX trait by default.
662+
>
663+
> ```diff
664+
> - .default(enabledTraits: []),
665+
> + .default(enabledTraits: ["MLX"]),
666+
> ```
667+
>
668+
> Pass environment variables with `TEST_RUNNER_` prefix:
669+
>
670+
> ```bash
671+
> export TEST_RUNNER_HF_TOKEN=your_huggingface_token
672+
> xcodebuild test \
673+
> -scheme AnyLanguageModel \
674+
> -destination 'platform=macOS' \
675+
> -only-testing:AnyLanguageModelTests/MLXLanguageModelTests
676+
> ```
677+
655678
## License
656679
657680
This project is available under the MIT license.

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,20 @@ import Foundation
8383
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
8484
let generateParameters = toGenerateParameters(options)
8585

86-
// Start with user prompt
86+
// Build chat history starting with system message if instructions are present
87+
var chat: [MLXLMCommon.Chat.Message] = []
88+
89+
// Add system message if instructions are present
90+
if let instructionSegments = extractInstructionSegments(from: session) {
91+
let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments)
92+
chat.append(systemMessage)
93+
}
94+
95+
// Add user prompt
8796
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
8897
let userMessage = convertSegmentsToMLXMessage(userSegments)
89-
var chat: [MLXLMCommon.Chat.Message] = [userMessage]
98+
chat.append(userMessage)
99+
90100
var allTextChunks: [String] = []
91101
var allEntries: [Transcript.Entry] = []
92102

@@ -211,6 +221,20 @@ import Foundation
211221
return [.text(.init(content: fallbackText))]
212222
}
213223

224+
/// Finds the instruction segments that should seed the system message.
///
/// The first `.instructions` entry in the session transcript wins and its
/// segments are returned as-is. When the transcript has no such entry, a
/// non-empty `session.instructions` description is wrapped in a single text
/// segment instead.
///
/// - Parameter session: The session whose instructions are being looked up.
/// - Returns: The instruction segments, or `nil` when neither source
///   provides any instructions.
private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? {
    // The first matching transcript entry short-circuits the search.
    for case .instructions(let transcriptInstructions) in session.transcript {
        return transcriptInstructions.segments
    }

    // No transcript entry: fall back to the session-level instructions text.
    guard let fallbackText = session.instructions?.description, !fallbackText.isEmpty else {
        return nil
    }
    return [.text(.init(content: fallbackText))]
}
237+
214238
private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
215239
var textParts: [String] = []
216240
var images: [MLXLMCommon.UserInput.Image] = []
@@ -248,6 +272,43 @@ import Foundation
248272
return MLXLMCommon.Chat.Message(role: .user, content: content, images: images)
249273
}
250274

275+
/// Converts instruction segments into a single MLX chat message with the
/// `.system` role.
///
/// Text segments contribute their content and structured segments contribute
/// their JSON string; the pieces are joined with newlines to form the message
/// content. Image segments are collected as attachments: URL sources are
/// passed through, while inline data is decoded via UIKit or AppKit
/// (whichever can be imported). Inline data that fails to decode is skipped.
///
/// - Parameter segments: The instruction segments to convert.
/// - Returns: A system-role message carrying the joined text and any images.
private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
    var lines: [String] = []
    var attachments: [MLXLMCommon.UserInput.Image] = []

    for segment in segments {
        switch segment {
        case .text(let textSegment):
            lines.append(textSegment.content)

        case .structure(let structureSegment):
            lines.append(structureSegment.content.jsonString)

        case .image(let imageSegment):
            switch imageSegment.source {
            case .url(let location):
                attachments.append(.url(location))

            case .data(let bytes, _):
                // Undecodable image data is dropped rather than treated as an error.
                #if canImport(UIKit)
                    guard let decoded = UIKit.UIImage(data: bytes),
                        let coreImage = CIImage(image: decoded)
                    else { continue }
                    attachments.append(.ciImage(coreImage))
                #elseif canImport(AppKit)
                    guard let decoded = AppKit.NSImage(data: bytes),
                        let bitmap = decoded.cgImage(forProposedRect: nil, context: nil, hints: nil)
                    else { continue }
                    attachments.append(.ciImage(CIImage(cgImage: bitmap)))
                #endif
            }
        }
    }

    return MLXLMCommon.Chat.Message(
        role: .system,
        content: lines.joined(separator: "\n"),
        images: attachments
    )
}
311+
251312
// MARK: - Tool Conversion
252313

253314
private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec {

0 commit comments

Comments
 (0)