Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -618,28 +618,51 @@ swift test

Tests for different language model backends have varying requirements:

- **CoreML tests**: `swift test --traits CoreML` + `ENABLE_COREML_TESTS=1` + `HF_TOKEN` (downloads model from HuggingFace)
- **MLX tests**: `swift test --traits MLX` + `ENABLE_MLX_TESTS=1` + `HF_TOKEN` (uses pre-defined model)
- **Llama tests**: `swift test --traits Llama` + `LLAMA_MODEL_PATH` (points to local GGUF file)
- **Anthropic tests**: `ANTHROPIC_API_KEY` (no traits needed)
- **OpenAI tests**: `OPENAI_API_KEY` (no traits needed)
- **Ollama tests**: No setup needed (skips in CI)
| Backend | Traits | Environment Variables |
|---------|--------|----------------------|
| CoreML | `CoreML` | `HF_TOKEN` |
| MLX | `MLX` | `HF_TOKEN` |
| Llama | `Llama` | `LLAMA_MODEL_PATH` |
| Anthropic | — | `ANTHROPIC_API_KEY` |
| OpenAI | — | `OPENAI_API_KEY` |
| Ollama | — | — |

Example setup for all backends:
Example setup for running multiple tests at once:

```bash
# Environment variables
export ENABLE_COREML_TESTS=1
export ENABLE_MLX_TESTS=1
export HF_TOKEN=your_huggingface_token
export LLAMA_MODEL_PATH=/path/to/model.gguf
export ANTHROPIC_API_KEY=your_anthropic_key
export OPENAI_API_KEY=your_openai_key

# Run the traits-gated CoreML and Llama tests (MLX requires xcodebuild; see note below)
swift test --traits CoreML,Llama
```

> [!TIP]
> Tests that perform generation are skipped in CI environments (when `CI` is set).
> Override this by setting `ENABLE_COREML_TESTS=1` or `ENABLE_MLX_TESTS=1`.

> [!NOTE]
> MLX tests must be run with `xcodebuild` rather than `swift test`
> due to Metal library loading requirements.
> Since `xcodebuild` doesn't support package traits directly,
> you'll first need to update `Package.swift` to enable the MLX trait by default.
>
> ```diff
> - .default(enabledTraits: []),
> + .default(enabledTraits: ["MLX"]),
> ```
>
> Pass environment variables with `TEST_RUNNER_` prefix:
>
> ```bash
> export TEST_RUNNER_HF_TOKEN=your_huggingface_token
> xcodebuild test \
> -scheme AnyLanguageModel \
> -destination 'platform=macOS' \
> -only-testing:AnyLanguageModelTests/MLXLanguageModelTests
> ```

## License

This project is available under the MIT license.
Expand Down
65 changes: 63 additions & 2 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,20 @@ import Foundation
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
let generateParameters = toGenerateParameters(options)

// Start with user prompt
// Build chat history starting with system message if instructions are present
var chat: [MLXLMCommon.Chat.Message] = []

// Add system message if instructions are present
if let instructionSegments = extractInstructionSegments(from: session) {
let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments)
chat.append(systemMessage)
}

// Add user prompt
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
let userMessage = convertSegmentsToMLXMessage(userSegments)
var chat: [MLXLMCommon.Chat.Message] = [userMessage]
chat.append(userMessage)

var allTextChunks: [String] = []
var allEntries: [Transcript.Entry] = []

Expand Down Expand Up @@ -211,6 +221,20 @@ import Foundation
return [.text(.init(content: fallbackText))]
}

/// Returns the segments to use as the system message, or `nil` if the session has no instructions.
///
/// The first `Transcript.Instructions` entry in the session's transcript wins;
/// otherwise the session-level `instructions` text (when non-empty) is wrapped
/// in a single text segment.
private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? {
    // The transcript is authoritative: use the earliest instructions entry if one exists.
    for case .instructions(let instructions) in session.transcript {
        return instructions.segments
    }
    // Fall back to the session's own instructions text, ignoring empty strings.
    guard let text = session.instructions?.description, !text.isEmpty else {
        return nil
    }
    return [.text(.init(content: text))]
}

private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
var textParts: [String] = []
var images: [MLXLMCommon.UserInput.Image] = []
Expand Down Expand Up @@ -248,6 +272,43 @@ import Foundation
return MLXLMCommon.Chat.Message(role: .user, content: content, images: images)
}

/// Converts transcript segments into a single MLX chat message with the `.system` role.
///
/// Text and structured segments are flattened to newline-joined text
/// (structured content is rendered as its JSON string). Image segments are
/// attached as MLX images: URL sources pass through directly, while raw data
/// is decoded via UIKit or AppKit (depending on platform) into a `CIImage`;
/// undecodable image data is silently dropped, matching the user-message converter.
private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
    var pieces: [String] = []
    var attachments: [MLXLMCommon.UserInput.Image] = []

    for segment in segments {
        switch segment {
        case .text(let text):
            pieces.append(text.content)

        case .structure(let structured):
            // Structured content is carried as its JSON representation.
            pieces.append(structured.content.jsonString)

        case .image(let imageSegment):
            switch imageSegment.source {
            case .url(let url):
                attachments.append(.url(url))

            case .data(let data, _):
                // Decode in-memory image bytes with the platform's native image type.
                #if canImport(UIKit)
                if let decoded = UIKit.UIImage(data: data),
                    let ciImage = CIImage(image: decoded)
                {
                    attachments.append(.ciImage(ciImage))
                }
                #elseif canImport(AppKit)
                if let decoded = AppKit.NSImage(data: data),
                    let cgImage = decoded.cgImage(forProposedRect: nil, context: nil, hints: nil)
                {
                    attachments.append(.ciImage(CIImage(cgImage: cgImage)))
                }
                #endif
            }
        }
    }

    return MLXLMCommon.Chat.Message(
        role: .system,
        content: pieces.joined(separator: "\n"),
        images: attachments
    )
}

// MARK: - Tool Conversion

private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec {
Expand Down