diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
index 379aa6a7..db7e8750 100644
--- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
+++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
@@ -178,17 +178,74 @@ import Foundation
         includeSchemaInPrompt: Bool,
         options: GenerationOptions
     ) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
-        // For now, only String is supported
         guard type == String.self else {
             fatalError("MLXLanguageModel only supports generating String content")
         }
 
-        // Streaming API in AnyLanguageModel currently yields once; return an empty snapshot
-        let empty = ""
-        return LanguageModelSession.ResponseStream(
-            content: empty as! Content,
-            rawContent: GeneratedContent(empty)
-        )
+        let modelId = self.modelId
+        let hub = self.hub
+        let directory = self.directory
+
+        let stream: AsyncThrowingStream<LanguageModelSession.ResponseStream<Content>.Snapshot, any Error> = .init { continuation in
+            let task = Task { @Sendable in
+                do {
+                    let context: ModelContext
+                    if let directory {
+                        context = try await loadModel(directory: directory)
+                    } else if let hub {
+                        context = try await loadModel(hub: hub, id: modelId)
+                    } else {
+                        context = try await loadModel(id: modelId)
+                    }
+
+                    let generateParameters = toGenerateParameters(options)
+
+                    var chat: [MLXLMCommon.Chat.Message] = []
+
+                    if let instructionSegments = extractInstructionSegments(from: session) {
+                        chat.append(convertSegmentsToMLXSystemMessage(instructionSegments))
+                    }
+
+                    let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
+                    chat.append(convertSegmentsToMLXMessage(userSegments))
+
+                    let userInput = MLXLMCommon.UserInput(
+                        chat: chat,
+                        processing: .init(resize: .init(width: 512, height: 512)),
+                        tools: nil
+                    )
+                    let lmInput = try await context.processor.prepare(input: userInput)
+
+                    let mlxStream = try MLXLMCommon.generate(
+                        input: lmInput,
+                        parameters: generateParameters,
+                        context: context
+                    )
+
+                    var accumulatedText = ""
+                    for await item in mlxStream {
+                        if Task.isCancelled { break }
+
+                        switch item {
+                        case .chunk(let text):
+                            accumulatedText += text
+                            let raw = GeneratedContent(accumulatedText)
+                            let content: Content.PartiallyGenerated = (accumulatedText as! Content).asPartiallyGenerated()
+                            continuation.yield(.init(content: content, rawContent: raw))
+                        case .info, .toolCall:
+                            break
+                        }
+                    }
+
+                    continuation.finish()
+                } catch {
+                    continuation.finish(throwing: error)
+                }
+            }
+            continuation.onTermination = { _ in task.cancel() }
+        }
+
+        return LanguageModelSession.ResponseStream(stream: stream)
     }
 }
 
diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
index 9f72715d..a66be321 100644
--- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
+++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
@@ -29,7 +29,7 @@ import Testing
         return false
     }()
 
-    @Suite("MLXLanguageModel", .enabled(if: shouldRunMLXTests))
+    @Suite("MLXLanguageModel", .enabled(if: shouldRunMLXTests), .serialized)
     struct MLXLanguageModelTests {
         // Qwen3-0.6B is a small model that supports tool calling
         let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit")
@@ -42,6 +42,19 @@ import Testing
             #expect(!response.content.isEmpty)
         }
 
+        @Test func streamingResponse() async throws {
+            let session = LanguageModelSession(model: model)
+
+            let stream = session.streamResponse(to: "Count to 5")
+            var chunks: [String] = []
+
+            for try await response in stream {
+                chunks.append(response.content)
+            }
+
+            #expect(!chunks.isEmpty)
+        }
+
         @Test func withGenerationOptions() async throws {
             let session = LanguageModelSession(model: model)
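
Usage note (informational, not part of the patch): a minimal sketch of driving the new streaming path from the caller's side, mirroring the streamingResponse test above. The AnyLanguageModel module name and the Qwen3 checkpoint are assumptions carried over from the test target.

    import AnyLanguageModel

    func printStreamedResponse() async throws {
        let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit")
        let session = LanguageModelSession(model: model)

        // Snapshots are cumulative: the implementation appends each chunk to
        // accumulatedText before yielding, so every iteration observes the
        // full text so far and the final snapshot is the complete response.
        for try await snapshot in session.streamResponse(to: "Count to 5") {
            print(snapshot.content)
        }
    }

Because snapshots accumulate rather than carry deltas, consumers can simply replace displayed text on each yield; cancellation of the consuming task propagates through continuation.onTermination to the inner generation Task.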