diff --git a/Sources/AnyLanguageModel/GenerationGuide.swift b/Sources/AnyLanguageModel/GenerationGuide.swift index 35b4c49a..c9ab16f9 100644 --- a/Sources/AnyLanguageModel/GenerationGuide.swift +++ b/Sources/AnyLanguageModel/GenerationGuide.swift @@ -2,7 +2,24 @@ import struct Foundation.Decimal import class Foundation.NSDecimalNumber /// Guides that control how values are generated. -public struct GenerationGuide {} +public struct GenerationGuide: Sendable { + package var minimumCount: Int? + package var maximumCount: Int? + package var minimum: Double? + package var maximum: Double? + + public init() {} + + package init(minimumCount: Int?, maximumCount: Int?) { + self.minimumCount = minimumCount + self.maximumCount = maximumCount + } + + package init(minimum: Double?, maximum: Double?) { + self.minimum = minimum + self.maximum = maximum + } +} // MARK: - String Guides @@ -45,7 +62,7 @@ extension GenerationGuide where Value == Int { /// } /// ``` public static func minimum(_ value: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: Double(value), maximum: nil) } /// Enforces a maximum value. @@ -65,7 +82,7 @@ extension GenerationGuide where Value == Int { /// } /// ``` public static func maximum(_ value: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: nil, maximum: Double(value)) } /// Enforces values fall within a range. @@ -85,7 +102,7 @@ extension GenerationGuide where Value == Int { /// } /// ``` public static func range(_ range: ClosedRange) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: Double(range.lowerBound), maximum: Double(range.upperBound)) } } @@ -144,18 +161,18 @@ extension GenerationGuide where Value == Double { /// Enforces a minimum value. /// The bounds are inclusive. public static func minimum(_ value: Double) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: value, maximum: nil) } /// Enforces a maximum value. /// The bounds are inclusive. 
public static func maximum(_ value: Double) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: nil, maximum: value) } /// Enforces values fall within a range. public static func range(_ range: ClosedRange) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: range.lowerBound, maximum: range.upperBound) } } @@ -168,7 +185,7 @@ extension GenerationGuide { /// The bounds are inclusive. public static func minimumCount(_ count: Int) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: count, maximumCount: nil) } /// Enforces a maximum number of elements in the array. @@ -176,25 +193,23 @@ extension GenerationGuide { /// The bounds are inclusive. public static func maximumCount(_ count: Int) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: nil, maximumCount: count) } /// Enforces that the number of elements in the array fall within a closed range. public static func count(_ range: ClosedRange) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: range.lowerBound, maximumCount: range.upperBound) } /// Enforces that the array has exactly a certain number elements. public static func count(_ count: Int) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: count, maximumCount: count) } /// Enforces a guide on the elements within the array. - public static func element(_ guide: GenerationGuide) -> GenerationGuide< - [Element] - > + public static func element(_ guide: GenerationGuide) -> GenerationGuide<[Element]> where Value == [Element] { GenerationGuide<[Element]>() } @@ -210,7 +225,7 @@ extension GenerationGuide where Value == [Never] { /// /// - Warning: This overload is only used for macro expansion. 
Don't call `GenerationGuide<[Never]>.minimumCount(_:)` on your own. public static func minimumCount(_ count: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: count, maximumCount: nil) } /// Enforces a maximum number of elements in the array. @@ -219,20 +234,20 @@ extension GenerationGuide where Value == [Never] { /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.maximumCount(_:)` on your own. public static func maximumCount(_ count: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: nil, maximumCount: count) } /// Enforces that the number of elements in the array fall within a closed range. /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own. public static func count(_ range: ClosedRange) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: range.lowerBound, maximumCount: range.upperBound) } /// Enforces that the array has exactly a certain number elements. /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own. public static func count(_ count: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: count, maximumCount: count) } } diff --git a/Sources/AnyLanguageModel/GenerationSchema.swift b/Sources/AnyLanguageModel/GenerationSchema.swift index a8065d96..69fd79d6 100644 --- a/Sources/AnyLanguageModel/GenerationSchema.swift +++ b/Sources/AnyLanguageModel/GenerationSchema.swift @@ -160,6 +160,16 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible ) } } + + var nodeDescription: String? 
{ + switch self { + case .object(let node): node.description + case .array(let node): node.description + case .string(let node): node.description + case .number(let node): node.description + case .boolean, .anyOf, .ref: nil + } + } } struct ObjectNode: Sendable, Codable { @@ -204,7 +214,7 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible } let root: Node - private var defs: [String: Node] + var defs: [String: Node] /// A string representation of the debug description. /// @@ -504,12 +514,32 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible } private static func nodesEqual(_ a: Node, _ b: Node) -> Bool { - // Simple structural equality - could be enhanced switch (a, b) { case (.boolean, .boolean): return true case (.ref(let aName), .ref(let bName)): return aName == bName + case (.string(let aString), .string(let bString)): + return aString.pattern == bString.pattern + && aString.enumChoices == bString.enumChoices + case (.number(let aNumber), .number(let bNumber)): + return aNumber.integerOnly == bNumber.integerOnly + && aNumber.minimum == bNumber.minimum + && aNumber.maximum == bNumber.maximum + case (.array(let aArray), .array(let bArray)): + return aArray.minItems == bArray.minItems + && aArray.maxItems == bArray.maxItems + && nodesEqual(aArray.items, bArray.items) + case (.object(let aObject), .object(let bObject)): + return aObject.required == bObject.required + && aObject.properties.keys == bObject.properties.keys + && aObject.properties.allSatisfy { key, aNode in + guard let bNode = bObject.properties[key] else { return false } + return nodesEqual(aNode, bNode) + } + case (.anyOf(let aNodes), .anyOf(let bNodes)): + return aNodes.count == bNodes.count + && zip(aNodes, bNodes).allSatisfy(nodesEqual) default: return false } @@ -693,16 +723,39 @@ extension GenerationSchema { } else if type == String.self { return (.string(StringNode(description: description, pattern: nil, enumChoices: nil)), [:]) } 
else if type == Int.self { + var minimum: Double? + var maximum: Double? + for guide in guides { + if let min = guide.minimum { minimum = min } + if let max = guide.maximum { maximum = max } + } return ( - .number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: true)), [:] + .number(NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: true)), [:] ) } else if type == Float.self || type == Double.self || type == Decimal.self { + var minimum: Double? + var maximum: Double? + for guide in guides { + if let min = guide.minimum { minimum = min } + if let max = guide.maximum { maximum = max } + } return ( - .number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: false)), [:] + .number(NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: false)), [:] ) } else { // Complex type - use its schema let schema = Value.generationSchema + + // Arrays should be inlined, not referenced + if case .array(var arrayNode) = schema.root { + arrayNode.description = description + for guide in guides { + if let min = guide.minimumCount { arrayNode.minItems = min } + if let max = guide.maximumCount { arrayNode.maxItems = max } + } + return (.array(arrayNode), schema.defs) + } + let typeName = String(reflecting: Value.self) var deps = schema.defs @@ -800,4 +853,40 @@ extension GenerationSchema { /// let data = try encoder.encode(schema) /// ``` static let omitAdditionalPropertiesKey = CodingUserInfoKey(rawValue: "GenerationSchema.omitAdditionalProperties")! + + package func schemaPrompt() -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + guard let data = try? encoder.encode(self), + let schemaJSON = String(data: data, encoding: .utf8) else { + return "Respond with valid JSON only." 
+ } + return "Respond with valid JSON matching this schema:\n\(schemaJSON)" + } +} + +extension Character { + package static let jsonQuoteScalars: Set = [0x22, 0x201C, 0x201D, 0x2018, 0x2019] + package static let jsonAllowedWhitespaceCharacters: Set = [" ", "\t", "\n"] + + package var containsEmojiScalar: Bool { + unicodeScalars.contains { scalar in + scalar.properties.isEmojiPresentation || scalar.properties.isEmoji + } + } + + package var isValidJSONStringCharacter: Bool { + guard self != "\\" else { return false } + guard let scalar = unicodeScalars.first, scalar.value >= 0x20 else { return false } + guard !Self.jsonQuoteScalars.contains(scalar.value) else { return false } + + if let ascii = asciiValue { + let char = Character(UnicodeScalar(ascii)) + if Self.jsonAllowedWhitespaceCharacters.contains(char) { return true } + return isLetter || isNumber || (isASCII && (isPunctuation || isSymbol)) + } + + // Allow non-ASCII letters/numbers and emoji, but disallow non-ASCII punctuation (e.g. 
"】") + return isLetter || isNumber || containsEmojiScalar + } } diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index a57da336..33ddba6c 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -478,14 +478,8 @@ import Foundation includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("LlamaLanguageModel only supports generating String content") - } - // Validate that no image segments are present try validateNoImageSegments(in: session) - try await ensureModelLoaded() let runtimeOptions = resolvedOptions(from: options) @@ -495,7 +489,6 @@ import Foundation guard let context = llama_init_from_model(model!, contextParams) else { throw LlamaLanguageModelError.contextInitializationFailed } - defer { llama_free(context) } // Check if this is an embedding model (no KV cache). @@ -510,22 +503,48 @@ import Foundation llama_set_warmup(context, false) llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads) - let maxTokens = runtimeOptions.maximumResponseTokens ?? 100 - let fullPrompt = try formatPrompt(for: session) + let fullPrompt: String + if includeSchemaInPrompt, type != String.self { + fullPrompt = try formatPrompt( + for: session, + extraSystemMessage: type.generationSchema.schemaPrompt() + ) + } else { + fullPrompt = try formatPrompt(for: session) + } - let text = try await generateText( - context: context, - model: model!, - prompt: fullPrompt, - maxTokens: maxTokens, - options: runtimeOptions - ) + if type == String.self { + let maxTokens = runtimeOptions.maximumResponseTokens ?? 
100 + let text = try await generateText( + context: context, + model: model!, + prompt: fullPrompt, + maxTokens: maxTokens, + options: runtimeOptions + ) - return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), - transcriptEntries: ArraySlice([]) - ) + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice([]) + ) + } else { + let maxTokens = runtimeOptions.maximumResponseTokens ?? 512 + let jsonString = try generateStructuredJSON( + context: context, + prompt: fullPrompt, + schema: type.generationSchema, + maxTokens: maxTokens, + options: runtimeOptions + ) + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice([]) + ) + } } public func streamResponse( @@ -840,6 +859,208 @@ import Foundation return generatedText } + // MARK: - Structured JSON Generation + + private func generateStructuredJSON( + context: OpaquePointer, + prompt: String, + schema: GenerationSchema, + maxTokens: Int, + options: ResolvedGenerationOptions + ) throws -> String { + guard let vocab = llama_model_get_vocab(model!) 
else { + throw LlamaLanguageModelError.contextInitializationFailed + } + + let promptTokens = try tokenizeText(vocab: vocab, text: prompt) + guard !promptTokens.isEmpty else { + throw LlamaLanguageModelError.tokenizationFailed + } + + var batch = llama_batch_init(Int32(options.batchSize), 0, 1) + defer { llama_batch_free(batch) } + + let hasEncoder = try prepareInitialBatch( + batch: &batch, + promptTokens: promptTokens, + model: model!, + vocab: vocab, + context: context, + batchSize: options.batchSize + ) + + guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else { + throw LlamaLanguageModelError.decodingFailed + } + defer { llama_sampler_free(sampler) } + let samplerPointer = UnsafeMutablePointer(sampler) + + if options.repeatPenalty != 1.0 || options.frequencyPenalty != 0.0 || options.presencePenalty != 0.0 { + llama_sampler_chain_add( + samplerPointer, + llama_sampler_init_penalties( + options.repeatLastN, + options.repeatPenalty, + options.frequencyPenalty, + options.presencePenalty + ) + ) + } + applySampling(sampler: samplerPointer, effectiveTemperature: options.temperature, options: options) + + let vocabSize = Int(llama_vocab_n_tokens(vocab)) + let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens + + return try withUnsafeMutablePointer(to: &batch) { batchPointer in + var backend = LlamaTokenBackend( + context: context, + vocab: vocab, + vocabSize: vocabSize, + sampler: samplerPointer, + batch: batchPointer, + position: initialPosition, + maximumTokens: maxTokens, + tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) } + ) + var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) + return try generator.generate() + } + } + + private struct LlamaTokenBackend: TokenBackend { + let context: OpaquePointer + let vocab: OpaquePointer + let vocabSize: Int + let sampler: UnsafeMutablePointer + let batch: UnsafeMutablePointer + let tokenToTextFn: (Int) -> String? 
+ let tokensExcludedFromRepetitionPenalty: Set + let endTokens: Set + + var position: Int32 + var remainingTokens: Int + let totalTokenBudget: Int + let eosToken: Int + + init( + context: OpaquePointer, + vocab: OpaquePointer, + vocabSize: Int, + sampler: UnsafeMutablePointer, + batch: UnsafeMutablePointer, + position: Int32, + maximumTokens: Int, + tokenToTextFn: @escaping (Int) -> String? + ) { + self.context = context + self.vocab = vocab + self.vocabSize = vocabSize + self.sampler = sampler + self.batch = batch + self.position = position + self.remainingTokens = maximumTokens + self.totalTokenBudget = maximumTokens + self.eosToken = Int(llama_vocab_eos(vocab)) + + let eotTokenValue = llama_vocab_eot(vocab) + let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int(eotTokenValue) : eosToken + self.endTokens = [self.eosToken, endOfTurnToken] + + self.tokenToTextFn = tokenToTextFn + self.tokensExcludedFromRepetitionPenalty = Self.buildTokensExcludedFromRepetitionPenalty( + vocabSize: vocabSize, + tokenToText: tokenToTextFn + ) + } + + func isSpecialToken(_ token: Int) -> Bool { + let attributes = llama_vocab_get_attr(vocab, llama_token(token)) + return (attributes.rawValue & LLAMA_TOKEN_ATTR_CONTROL.rawValue) != 0 + } + + private static func buildTokensExcludedFromRepetitionPenalty( + vocabSize: Int, + tokenToText: (Int) -> String? 
+ ) -> Set { + let excludedTexts: Set = ["{", "}", "[", "]", ",", ":", "\""] + var excluded = Set() + excluded.reserveCapacity(excludedTexts.count * 4) + + for token in 0 ..< vocabSize { + guard let text = tokenToText(token) else { continue } + let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines) + if excludedTexts.contains(trimmed) { + excluded.insert(token) + } + } + + return excluded + } + + func tokenize(_ text: String) throws -> [Int] { + let utf8Count = text.utf8.count + let capacity = Int32(max(utf8Count * 2, 8)) + let tokens = UnsafeMutablePointer.allocate(capacity: Int(capacity)) + defer { tokens.deallocate() } + + let tokenCount = llama_tokenize( + vocab, + text, + Int32(utf8Count), + tokens, + capacity, + false, + false + ) + guard tokenCount > 0 else { return [] } + return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount))).map { Int($0) } + } + + func tokenText(_ token: Int) -> String? { + tokenToTextFn(token) + } + + mutating func decode(_ token: Int) throws { + let llamaToken = llama_token(token) + + batch.pointee.n_tokens = 1 + batch.pointee.token[0] = llamaToken + batch.pointee.pos[0] = position + batch.pointee.n_seq_id[0] = 1 + if let seqIds = batch.pointee.seq_id, let seqId = seqIds[0] { + seqId[0] = 0 + } + batch.pointee.logits[0] = 1 + + position += 1 + remainingTokens -= 1 + + let decodeResult = llama_decode(context, batch.pointee) + guard decodeResult == 0 else { + throw LlamaLanguageModelError.decodingFailed + } + + if !tokensExcludedFromRepetitionPenalty.contains(Int(llamaToken)) { + llama_sampler_accept(sampler, llamaToken) + } + } + + mutating func sample(from allowedTokens: Set) throws -> Int { + guard let logits = llama_get_logits(context) else { + return eosToken + } + + for tokenIndex in 0 ..< vocabSize { + if !allowedTokens.contains(tokenIndex) { + logits[tokenIndex] = -Float.infinity + } + } + + let tokenIndex = batch.pointee.n_tokens - 1 + return Int(llama_sampler_sample(sampler, context, tokenIndex)) + 
} + } + private func generateTextStream( context: OpaquePointer, model: OpaquePointer, @@ -1082,7 +1303,10 @@ import Foundation return hasEncoder } - private func formatPrompt(for session: LanguageModelSession) throws -> String { + private func formatPrompt( + for session: LanguageModelSession, + extraSystemMessage: String? = nil + ) throws -> String { guard let model = self.model else { throw LlamaLanguageModelError.modelLoadFailed } @@ -1114,6 +1338,10 @@ import Foundation } } + if let extraSystemMessage, !extraSystemMessage.isEmpty { + messages.append(("system", extraSystemMessage)) + } + // Keep C strings alive while using them let cRoles = messages.map { strdup($0.role) } let cContents = messages.map { strdup($0.content) } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 77accfb5..12a57ae2 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -58,11 +58,6 @@ import Foundation includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("MLXLanguageModel only supports generating String content") - } - let context: ModelContext if let directory { context = try await loadModel(directory: directory) @@ -72,6 +67,23 @@ import Foundation context = try await loadModel(id: modelId) } + if type != String.self { + let jsonString = try await generateStructuredJSON( + context: context, + session: session, + prompt: prompt, + schema: type.generationSchema, + options: options + ) + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice([]) + ) + } + // Convert session tools to MLX 
ToolSpec format let toolSpecs: [ToolSpec]? = session.tools.isEmpty @@ -168,7 +180,7 @@ import Foundation options: GenerationOptions ) -> sending LanguageModelSession.ResponseStream where Content: Generable { guard type == String.self else { - fatalError("MLXLanguageModel only supports generating String content") + fatalError("MLXLanguageModel streaming only supports String content") } let modelId = self.modelId @@ -189,7 +201,6 @@ import Foundation let generateParameters = toGenerateParameters(options) - // Build chat history from full transcript let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) let userInput = MLXLMCommon.UserInput( @@ -248,6 +259,20 @@ import Foundation ) } + private func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { + MLXLMCommon.GenerateParameters( + maxTokens: options.maximumResponseTokens, + maxKVSize: nil, + kvBits: nil, + kvGroupSize: 64, + quantizedKVStart: 0, + temperature: Float(options.temperature ?? 0.2), + topP: 0.95, + repetitionPenalty: 1.1, + repetitionContextSize: 64 + ) + } + // MARK: - Transcript Conversion private func convertTranscriptToMLXChat( @@ -358,28 +383,72 @@ import Foundation private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec { // Convert AnyLanguageModel's GenerationSchema to JSON-compatible dictionary - let parametersDict: [String: Any] + let parametersDict: [String: any Sendable] do { let resolvedSchema = tool.parameters.withResolvedRoot() ?? tool.parameters let encoder = JSONEncoder() let data = try encoder.encode(resolvedSchema) if let json = try JSONSerialization.jsonObject(with: data) as? 
[String: Any] { - parametersDict = json + parametersDict = try convertToSendableJSONObject(json) } else { - parametersDict = ["type": "object", "properties": [:], "required": []] + parametersDict = makeEmptyJSONSchemaObject() } } catch { - parametersDict = ["type": "object", "properties": [:], "required": []] + parametersDict = makeEmptyJSONSchemaObject() } - return [ + let functionSpec: [String: any Sendable] = [ + "name": tool.name, + "description": tool.description, + "parameters": parametersDict, + ] + + let toolSpec: ToolSpec = [ "type": "function", - "function": [ - "name": tool.name, - "description": tool.description, - "parameters": parametersDict, - ], + "function": functionSpec, ] + + return toolSpec + } + + private func makeEmptyJSONSchemaObject() -> [String: any Sendable] { + [ + "type": "object", + "properties": [String: any Sendable](), + "required": [String](), + ] + } + + private func convertToSendableJSONObject(_ object: [String: Any]) throws -> [String: any Sendable] { + var converted: [String: any Sendable] = [:] + converted.reserveCapacity(object.count) + + for (key, value) in object { + converted[key] = try convertToSendableJSONValue(value) + } + return converted + } + + private func convertToSendableJSONValue(_ value: Any) throws -> any Sendable { + if value is NSNull { return MLXLMCommon.JSONValue.null } + if let stringValue = value as? String { return stringValue } + if let boolValue = value as? Bool { return boolValue } + if let intValue = value as? Int { return intValue } + if let doubleValue = value as? Double { return doubleValue } + if let numberValue = value as? NSNumber { + if CFGetTypeID(numberValue) == CFBooleanGetTypeID() { + return numberValue.boolValue + } + return numberValue.doubleValue + } + if let arrayValue = value as? [Any] { + return try arrayValue.map { try convertToSendableJSONValue($0) } + } + if let dictionaryValue = value as? 
[String: Any] { + return try convertToSendableJSONObject(dictionaryValue) + } + + throw StructuredGenerationError.unsupportedJSONValueType } // MARK: - Tool Invocation Handling @@ -464,4 +533,251 @@ import Foundation } return textParts.joined(separator: "\n") } + + // MARK: - Structured JSON Generation + + private enum StructuredGenerationError: Error { + case invalidVocabSize + case unsupportedJSONValueType + } + + private func generateStructuredJSON( + context: ModelContext, + session: LanguageModelSession, + prompt: Prompt, + schema: GenerationSchema, + options: GenerationOptions + ) async throws -> String { + let maxTokens = options.maximumResponseTokens ?? 512 + let generateParameters = toStructuredGenerateParameters(options) + + let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schema.schemaPrompt()) + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: nil + ) + let lmInput = try await context.processor.prepare(input: userInput) + + let backend = try MLXTokenBackend( + context: context, + input: lmInput, + parameters: generateParameters, + maximumTokens: maxTokens + ) + + var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) + let json = try generator.generate() + Stream().synchronize() + return json + } + + private func normalizeChatForStructuredGeneration( + _ chat: [MLXLMCommon.Chat.Message], + schemaPrompt: String + ) -> [MLXLMCommon.Chat.Message] { + var systemMessageParts: [String] = [] + systemMessageParts.append(schemaPrompt) + + var messages: [MLXLMCommon.Chat.Message] = [] + messages.reserveCapacity(chat.count) + + for message in chat { + if message.role == .system { + systemMessageParts.append(message.content) + continue + } + + if let last = messages.last, last.role == message.role { + let merged = MLXLMCommon.Chat.Message(role: 
last.role, content: "\(last.content)\n\(message.content)") + messages.removeLast() + messages.append(merged) + } else { + messages.append(message) + } + } + + let systemPrefix = systemMessageParts + .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + .filter { !$0.isEmpty } + .joined(separator: "\n\n") + + guard !systemPrefix.isEmpty else { + return messages + } + + if let firstUserIndex = messages.firstIndex(where: { $0.role == .user }) { + let existing = messages[firstUserIndex].content + messages[firstUserIndex] = MLXLMCommon.Chat.Message(role: .user, content: "\(systemPrefix)\n\n\(existing)") + return messages + } + + messages.insert(.init(role: .user, content: systemPrefix), at: 0) + return messages + } + + private struct MLXTokenBackend: TokenBackend { + let model: any MLXLMCommon.LanguageModel + let tokenizer: any Tokenizer + var state: MLXLMCommon.LMOutput.State? + var cache: [MLXLMCommon.KVCache] + var processor: MLXLMCommon.LogitProcessor? + let sampler: MLXLMCommon.LogitSampler + let tokensExcludedFromRepetitionPenalty: Set + let endTokens: Set + + var currentLogits: MLXArray + let vocabSize: Int + let eosToken: Int + var remainingTokens: Int + let totalTokenBudget: Int + + init( + context: ModelContext, + input: MLXLMCommon.LMInput, + parameters: MLXLMCommon.GenerateParameters, + maximumTokens: Int + ) throws { + self.model = context.model + self.tokenizer = context.tokenizer + self.state = nil + self.cache = context.model.newCache(parameters: parameters) + self.processor = parameters.processor() + self.sampler = parameters.sampler() + self.remainingTokens = maximumTokens + self.totalTokenBudget = maximumTokens + guard let eosTokenId = context.tokenizer.eosTokenId else { + throw StructuredGenerationError.invalidVocabSize + } + self.eosToken = eosTokenId + self.endTokens = Self.buildEndTokens( + eosTokenId: eosTokenId, + tokenizer: context.tokenizer, + configuration: context.configuration + ) + + self.tokensExcludedFromRepetitionPenalty = 
Self.buildTokensExcludedFromRepetitionPenalty(tokenizer: context.tokenizer) + + processor?.prompt(input.text.tokens) + + let prepareResult = try context.model.prepare( + input, + cache: cache, + windowSize: parameters.prefillStepSize + ) + + let output: MLXLMCommon.LMOutput + switch prepareResult { + case .tokens(let tokensToProcess): + output = context.model( + tokensToProcess[text: .newAxis], + cache: cache, + state: state + ) + case .logits(let logitsOutput): + output = logitsOutput + } + + self.state = output.state + self.currentLogits = output.logits + + guard output.logits.shape.count >= 1 else { + throw StructuredGenerationError.invalidVocabSize + } + self.vocabSize = output.logits.shape.last ?? 0 + guard self.vocabSize > 0 else { + throw StructuredGenerationError.invalidVocabSize + } + } + + private static func buildEndTokens( + eosTokenId: Int, + tokenizer: any Tokenizer, + configuration: ModelConfiguration + ) -> Set { + var tokens: Set = [eosTokenId] + + // If the tokenizer declares an EOS token string, prefer treating its ID as an end token too. + // Some chat models use a string EOS marker (e.g. "") whose ID may differ from eosTokenId. + if let eosString = tokenizer.eosToken, let eosStringId = tokenizer.convertTokenToId(eosString) { + tokens.insert(eosStringId) + } + + for tokenString in configuration.extraEOSTokens { + if let id = tokenizer.convertTokenToId(tokenString) { + tokens.insert(id) + } + } + return tokens + } + + func isSpecialToken(_ token: Int) -> Bool { + // Use swift-transformers' own special token registry (skipSpecialTokens) instead of guessing. 
+ let raw = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + guard !raw.isEmpty else { return false } + let filtered = tokenizer.decode(tokens: [token], skipSpecialTokens: true) + return filtered.isEmpty + } + + private static func buildTokensExcludedFromRepetitionPenalty(tokenizer: any Tokenizer) -> Set { + let excludedTexts = ["{", "}", "[", "]", ",", ":", "\""] + var excluded = Set() + excluded.reserveCapacity(excludedTexts.count * 2) + + for text in excludedTexts { + let tokens = tokenizer.encode(text: text, addSpecialTokens: false) + for token in tokens { + excluded.insert(token) + } + } + + return excluded + } + + func tokenize(_ text: String) throws -> [Int] { + tokenizer.encode(text: text, addSpecialTokens: false) + } + + func tokenText(_ token: Int) -> String? { + let decoded = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + return decoded.isEmpty ? nil : decoded + } + + mutating func decode(_ token: Int) throws { + let inputText = MLXLMCommon.LMInput.Text(tokens: MLXArray([Int32(token)])) + let output = model( + inputText[text: .newAxis], + cache: cache.isEmpty ? nil : cache, + state: state + ) + state = output.state + currentLogits = output.logits + remainingTokens -= 1 + + if !tokensExcludedFromRepetitionPenalty.contains(token) { + let tokenArray = MLXArray(Int32(token)) + processor?.didSample(token: tokenArray) + } + } + + mutating func sample(from allowedTokens: Set) throws -> Int { + guard !allowedTokens.isEmpty else { + throw ConstrainedGenerationError.tokenizationFailed + } + + var logits = currentLogits[0..., -1, 0...] + logits = processor?.process(logits: logits) ?? 
logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } + + let allowedIndices = MLXArray(allowedTokens.map { UInt32($0) }) + let maskedLogits = full(logits.shape, values: -Float.infinity) + maskedLogits[0..., allowedIndices] = logits[0..., allowedIndices] + + let sampledToken = sampler.sample(logits: maskedLogits) + return sampledToken.item(Int.self) + } + } #endif // MLX diff --git a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift index df0b9f54..0360e998 100644 --- a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift @@ -81,17 +81,19 @@ transcript: session.transcript.toFoundationModels(instructions: session.instructions) ) - let fmResponse = try await fmSession.respond(to: fmPrompt, options: fmOptions) - let generatedContent = GeneratedContent(fmResponse.content) - if type == String.self { + let fmResponse = try await fmSession.respond(to: fmPrompt, options: fmOptions) + let generatedContent = GeneratedContent(fmResponse.content) return LanguageModelSession.Response( content: fmResponse.content as! Content, rawContent: generatedContent, transcriptEntries: [] ) } else { - // For non-String types, try to create an instance from the generated content + // For non-String types, use schema-based structured generation + let schema = FoundationModels.GenerationSchema(type.generationSchema) + let fmResponse = try await fmSession.respond(to: fmPrompt, schema: schema, options: fmOptions) + let generatedContent = try AnyLanguageModel.GeneratedContent(fmResponse.content) let content = try type.init(generatedContent) return LanguageModelSession.Response( @@ -321,25 +323,15 @@ extension FoundationModels.GenerationSchema { internal init(_ content: AnyLanguageModel.GenerationSchema) { let resolvedSchema = content.withResolvedRoot() ?? content - - let rawParameters = try? 
JSONValue(resolvedSchema) - var schema: FoundationModels.GenerationSchema? = nil - if rawParameters?.objectValue is [String: JSONValue] { - if let data = try? JSONEncoder().encode(rawParameters) { - if let jsonSchema = try? JSONDecoder().decode(JSONSchema.self, from: data) { - let dynamicSchema = convertToDynamicSchema(jsonSchema) - schema = try? FoundationModels.GenerationSchema(root: dynamicSchema, dependencies: []) - } - } + let dynamicSchema = convertToDynamicSchema(resolvedSchema.root) + let dependencies = resolvedSchema.defs.map { name, node in + convertToDynamicSchema(node, name: name) } - if let schema = schema { + + if let schema = try? FoundationModels.GenerationSchema(root: dynamicSchema, dependencies: dependencies) { self = schema } else { - self = FoundationModels.GenerationSchema( - type: String.self, - properties: [] - ) - + self = FoundationModels.GenerationSchema(type: String.self, properties: []) } } } @@ -368,6 +360,78 @@ } } + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + func convertToDynamicSchema( + _ node: GenerationSchema.Node, + name: String? = nil + ) -> FoundationModels.DynamicGenerationSchema { + switch node { + case .object(let objectNode): + return .init( + name: name ?? "", + description: objectNode.description, + properties: objectNode.properties.map { key, value in + .init( + name: key, + description: value.nodeDescription, + schema: convertToDynamicSchema(value), + isOptional: !objectNode.required.contains(key) + ) + } + ) + + case .string(let stringNode): + if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { + return .init( + name: name ?? "", + description: stringNode.description, + anyOf: enumChoices.map { .init(type: String.self, guides: [.constant($0)]) } + ) + } + if let pattern = stringNode.pattern, let regex = try? 
Regex(pattern) { + return .init(type: String.self, guides: [.pattern(regex)]) + } + return .init(type: String.self) + + case .number(let numberNode): + return numberNode.integerOnly + ? .init(type: Int.self, guides: intGuides(numberNode)) + : .init(type: Double.self, guides: doubleGuides(numberNode)) + + case .boolean: + return .init(type: Bool.self) + + case .array(let arrayNode): + return .init( + arrayOf: convertToDynamicSchema(arrayNode.items), + minimumElements: arrayNode.minItems, + maximumElements: arrayNode.maxItems + ) + + case .anyOf(let nodes): + return .init(name: "", anyOf: nodes.map { convertToDynamicSchema($0) }) + + case .ref(let refName): + return .init(referenceTo: refName) + } + } + + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + private func intGuides(_ numberNode: GenerationSchema.NumberNode) -> [FoundationModels.GenerationGuide] { + var guides: [FoundationModels.GenerationGuide] = [] + if let minimum = numberNode.minimum { guides.append(.minimum(Int(minimum))) } + if let maximum = numberNode.maximum { guides.append(.maximum(Int(maximum))) } + return guides + } + + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + private func doubleGuides(_ numberNode: GenerationSchema.NumberNode) -> [FoundationModels.GenerationGuide] { + var guides: [FoundationModels.GenerationGuide] = [] + if let minimum = numberNode.minimum { guides.append(.minimum(minimum)) } + if let maximum = numberNode.maximum { guides.append(.maximum(maximum)) } + return guides + } + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) func convertToDynamicSchema(_ jsonSchema: JSONSchema) -> FoundationModels.DynamicGenerationSchema { switch jsonSchema { @@ -378,10 +442,13 @@ return .init(name: "", description: jsonSchema.description, properties: schemaProperties) case .string(_, _, _, _, _, _, _, _, pattern: let pattern, _): - var guides: [FoundationModels.GenerationGuide] = [] - if let values = 
jsonSchema.enum?.compactMap(\.stringValue), !values.isEmpty { - guides.append(.anyOf(values)) + if let enumValues = jsonSchema.enum?.compactMap(\.stringValue), !enumValues.isEmpty { + let enumSchemas = enumValues.map { + FoundationModels.DynamicGenerationSchema(type: String.self, guides: [.constant($0)]) + } + return .init(name: "", description: jsonSchema.description, anyOf: enumSchemas) } + var guides: [FoundationModels.GenerationGuide] = [] if let value = jsonSchema.const?.stringValue { guides.append(.constant(value)) } @@ -397,12 +464,8 @@ } var guides: [FoundationModels.GenerationGuide] = [] - if let min = minimum { - guides.append(.minimum(min)) - } - if let max = maximum { - guides.append(.maximum(max)) - } + if let minimum { guides.append(.minimum(minimum)) } + if let maximum { guides.append(.maximum(maximum)) } if let value = jsonSchema.const?.intValue { guides.append(.range(value ... value)) } @@ -415,12 +478,8 @@ } var guides: [FoundationModels.GenerationGuide] = [] - if let min = minimum { - guides.append(.minimum(min)) - } - if let max = maximum { - guides.append(.maximum(max)) - } + if let minimum { guides.append(.minimum(minimum)) } + if let maximum { guides.append(.maximum(maximum)) } if let value = jsonSchema.const?.doubleValue { guides.append(.range(value ... value)) } diff --git a/Sources/AnyLanguageModel/StructuredGeneration.swift b/Sources/AnyLanguageModel/StructuredGeneration.swift new file mode 100644 index 00000000..9a8bd0ed --- /dev/null +++ b/Sources/AnyLanguageModel/StructuredGeneration.swift @@ -0,0 +1,314 @@ +import Foundation + +// MARK: - Token Backend + +/// Abstracts token-level operations for structured JSON generation. +package protocol TokenBackend { + func tokenize(_ text: String) throws -> [Int] + func tokenText(_ token: Int) -> String? 
+ func isSpecialToken(_ token: Int) -> Bool + mutating func decode(_ token: Int) throws + mutating func sample(from allowedTokens: Set) throws -> Int + + var eosToken: Int { get } + var endTokens: Set { get } + var vocabSize: Int { get } + var remainingTokens: Int { get set } + var totalTokenBudget: Int { get } +} + +// MARK: - JSON Generator + +/// Generates JSON conforming to a schema using constrained token sampling. +package struct ConstrainedJSONGenerator { + private var backend: Backend + private let schema: GenerationSchema + + private let quoteToken: Int + + private let stringTerminators: Set + private let stringInitialAllowedTokens: Set + private let stringContinuationAllowedTokens: Set + + private let basicTerminators: Set + private let integerTerminators: Set + private let doubleTerminators: Set + + package init(backend: Backend, schema: GenerationSchema) throws { + self.backend = backend + self.schema = schema + + guard let quoteToken = try backend.tokenize("\"").first else { + throw ConstrainedGenerationError.tokenizationFailed + } + self.quoteToken = quoteToken + + self.stringTerminators = backend.endTokens.union([quoteToken]) + + var structuralTerminators = backend.endTokens + for structuralText in [",", "}", "]", ":"] { + if let token = try backend.tokenize(structuralText).first { + structuralTerminators.insert(token) + } + } + self.basicTerminators = structuralTerminators + self.integerTerminators = Self.buildValidIntegerTokens(backend: backend).union(structuralTerminators) + self.doubleTerminators = Self.buildValidDecimalTokens(backend: backend).union(structuralTerminators) + + let stringContentTokens = Self.buildValidStringTokens(backend: backend) + self.stringInitialAllowedTokens = stringContentTokens + self.stringContinuationAllowedTokens = stringContentTokens.union(stringTerminators) + } + + package mutating func generate() throws -> String { + try generateNode(schema.root) + } + + private static func buildValidStringTokens(backend: Backend) 
-> Set { + let allowedWhitespace: Set = [" ", "\t", "\n"] + var allowed = Set() + allowed.reserveCapacity(backend.vocabSize / 4) + + for token in 0 ..< backend.vocabSize { + if backend.endTokens.contains(token) { continue } + if backend.isSpecialToken(token) { continue } + guard let text = backend.tokenText(token), !text.isEmpty else { continue } + guard text.allSatisfy({ $0.isValidJSONStringCharacter }) else { continue } + + if text.allSatisfy({ $0.isWhitespace }) { + if text.count == 1, let char = text.first, allowedWhitespace.contains(char) { + allowed.insert(token) + } + } else { + allowed.insert(token) + } + } + return allowed + } + + private static func buildValidIntegerTokens(backend: Backend) -> Set { + var allowed = Set() + for token in 0 ..< backend.vocabSize { + guard let text = backend.tokenText(token), !text.isEmpty else { continue } + if text.allSatisfy({ $0.isNumber || $0 == "-" }) { + allowed.insert(token) + } + } + return allowed + } + + private static func buildValidDecimalTokens(backend: Backend) -> Set { + var allowed = Set() + for token in 0 ..< backend.vocabSize { + guard let text = backend.tokenText(token), !text.isEmpty else { continue } + if text.allSatisfy({ $0.isNumber || $0 == "-" || $0 == "." }) { + allowed.insert(token) + } + } + return allowed + } + + private mutating func emit(_ text: String) throws -> String { + for token in try backend.tokenize(text) { + guard backend.remainingTokens > 0 else { + throw ConstrainedGenerationError.tokenBudgetExceeded + } + try backend.decode(token) + } + return text + } + + private func maxFreeStringTokens() -> Int { + let perStringLimit = max(32, backend.totalTokenBudget / 4) + return min(backend.remainingTokens, perStringLimit) + } + + private mutating func generateFreeString(maxTokens: Int) throws -> String { + var result = "" + var generated = 0 + + while backend.remainingTokens > 0, generated < maxTokens { + let allowed = result.isEmpty ? 
stringInitialAllowedTokens : stringContinuationAllowedTokens + let token = try backend.sample(from: allowed) + if stringTerminators.contains(token) { break } + + var text = backend.tokenText(token) ?? "" + if result.last?.isWhitespace == true && text.first?.isWhitespace == true { + text = String(text.drop(while: { $0.isWhitespace })) + } + result += text + generated += 1 + try backend.decode(token) + } + + return result + } + + private mutating func generateChoice(_ candidates: [String]) throws -> String { + let tokenized = try candidates.map { try backend.tokenize($0) }.filter { !$0.isEmpty } + guard !tokenized.isEmpty else { + throw ConstrainedGenerationError.tokenizationFailed + } + + var prefixes = tokenized + var emitted = "" + var position = 0 + + while backend.remainingTokens > 0 { + if prefixes.contains(where: { $0.count == position }) { break } + + let allowed = Set(prefixes.compactMap { tokens -> Int? in + guard position < tokens.count else { return nil } + return tokens[position] + }) + + let token = try backend.sample(from: allowed) + emitted += backend.tokenText(token) ?? "" + try backend.decode(token) + + prefixes = prefixes.filter { $0.count > position && $0[position] == token } + position += 1 + if prefixes.isEmpty { break } + } + + return emitted + } + + private mutating func generateNumber(_ node: GenerationSchema.NumberNode) throws -> String { + let allowedTokens = node.integerOnly ? integerTerminators : doubleTerminators + var result = "" + let maxTokens = 16 + + while backend.remainingTokens > 0, result.count < maxTokens { + let token = try backend.sample(from: allowedTokens) + if basicTerminators.contains(token) { break } + + guard let text = backend.tokenText(token) else { break } + result += text + try backend.decode(token) + } + + return clampNumberString(result.isEmpty ? 
"0" : result, node: node) + } + + private func clampNumberString(_ text: String, node: GenerationSchema.NumberNode) -> String { + if node.integerOnly { + let value = Int(text) ?? 0 + let clamped = clampInt(value, min: node.minimum, max: node.maximum) + return String(clamped) + } else { + let value = Double(text) ?? 0 + let clamped = clampDouble(value, min: node.minimum, max: node.maximum) + return formatDouble(clamped) + } + } + + private func clampInt(_ value: Int, min: Double?, max: Double?) -> Int { + let lower = min.map { Int(ceil($0)) } + let upper = max.map { Int(floor($0)) } + return clamp(value, min: lower, max: upper) + } + + private func clampDouble(_ value: Double, min: Double?, max: Double?) -> Double { + clamp(value, min: min, max: max) + } + + private func clamp(_ value: T, min: T?, max: T?) -> T { + var result = value + if let min { result = Swift.max(result, min) } + if let max { result = Swift.min(result, max) } + return result + } + + private func formatDouble(_ value: Double) -> String { + if value.truncatingRemainder(dividingBy: 1) == 0 { + return String(Int(value)) + } + let formatted = String(format: "%.6g", value) + return formatted + } + + private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { + guard backend.remainingTokens > 0 else { + throw ConstrainedGenerationError.tokenBudgetExceeded + } + + switch node { + case .object(let objectNode): + return try generateObject(objectNode) + case .array(let arrayNode): + return try generateArray(arrayNode) + case .string(let stringNode): + return try generateString(stringNode) + case .number(let numberNode): + return try generateNumber(numberNode) + case .boolean: + return try generateChoice(["true", "false"]) + case .ref(let typeName): + guard let referenced = schema.defs[typeName] else { + throw ConstrainedGenerationError.missingReference(typeName) + } + return try generateNode(referenced) + case .anyOf(let variants): + guard let first = variants.first else { + throw 
ConstrainedGenerationError.emptyAnyOf + } + return try generateNode(first) + } + } + + private mutating func generateObject(_ node: GenerationSchema.ObjectNode) throws -> String { + let keys = node.properties.keys.sorted() + var output = try emit("{") + + for (index, key) in keys.enumerated() { + output += try emit("\"\(key)\":") + output += try generateNode(node.properties[key] ?? .string(.init())) + + if index < keys.count - 1 { + output += try emit(",") + } + } + + output += try emit("}") + return output + } + + private mutating func generateArray(_ node: GenerationSchema.ArrayNode) throws -> String { + let count = node.minItems ?? node.maxItems ?? 4 + var output = try emit("[") + + for index in 0 ..< count { + output += try generateNode(node.items) + if index < count - 1 { + output += try emit(",") + } + } + + output += try emit("]") + return output + } + + private mutating func generateString(_ node: GenerationSchema.StringNode) throws -> String { + var output = try emit("\"") + + if let choices = node.enumChoices, !choices.isEmpty { + output += try generateChoice(choices) + } else { + let content = try generateFreeString(maxTokens: maxFreeStringTokens()) + output += content.trimmingCharacters(in: .whitespaces) + } + + output += try emit("\"") + return output + } +} + +// MARK: - Errors + +package enum ConstrainedGenerationError: Error { + case tokenizationFailed + case tokenBudgetExceeded + case missingReference(String) + case emptyAnyOf +} diff --git a/Sources/AnyLanguageModelMacros/GenerableMacro.swift b/Sources/AnyLanguageModelMacros/GenerableMacro.swift index a44707e7..2c1d027f 100644 --- a/Sources/AnyLanguageModelMacros/GenerableMacro.swift +++ b/Sources/AnyLanguageModelMacros/GenerableMacro.swift @@ -106,19 +106,15 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { let binding = varDecl.bindings.first, let identifier = binding.pattern.as(IdentifierPatternSyntax.self) { - let propertyName = identifier.identifier.text let propertyType = 
binding.typeAnnotation?.type.description ?? "String" - let guideInfo = extractGuideInfo(from: varDecl.attributes) properties.append( PropertyInfo( name: propertyName, type: propertyType, - guideDescription: guideInfo.description, - guides: guideInfo.guides, - pattern: guideInfo.pattern + guide: guideInfo ) ) } @@ -140,32 +136,96 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { in: .init(charactersIn: "\"") ) - var guides: [String] = [] - var pattern: String? = nil + var constraints = Constraints() for arg in Array(arguments.dropFirst()) { - let argText = arg.expression.description + let guideExpression = arg.expression + if let parsedPattern = parsePatternFromExpression(guideExpression) { + constraints.pattern = parsedPattern + continue + } - if argText.contains(".pattern(") { - let patternRegex = #/\.pattern\(\"([^\"]*)\"\)/# - if let match = argText.firstMatch(of: patternRegex) { - pattern = String(match.1) - } - } else if argText.contains("pattern(") { - let patternRegex = #/pattern\(\"([^\"]*)\"\)/# - if let match = argText.firstMatch(of: patternRegex) { - pattern = String(match.1) - } - } else { - guides.append(argText) + if let functionCall = guideExpression.as(FunctionCallExprSyntax.self) { + applyConstraints(from: functionCall, into: &constraints) + } else if let memberAccess = guideExpression.as(MemberAccessExprSyntax.self), + let functionCall = memberAccess.base?.as(FunctionCallExprSyntax.self) { + applyConstraints(from: functionCall, into: &constraints) } } - return GuideInfo(description: description, guides: guides, pattern: pattern) + return GuideInfo(description: description, constraints: constraints) } } } - return GuideInfo(description: nil, guides: [], pattern: nil) + return GuideInfo(description: nil, constraints: Constraints()) + } + + private static func applyConstraints(from call: FunctionCallExprSyntax, into constraints: inout Constraints) { + let functionName: String? 
+ if let memberAccess = call.calledExpression.as(MemberAccessExprSyntax.self) { + functionName = memberAccess.declName.baseName.text + } else if let identifier = call.calledExpression.as(DeclReferenceExprSyntax.self) { + functionName = identifier.baseName.text + } else { + functionName = nil + } + + guard let functionName, let firstArgument = call.arguments.first else { return } + + switch functionName { + case "count": + if let intLiteral = firstArgument.expression.as(IntegerLiteralExprSyntax.self), + let value = Int(intLiteral.literal.text) { + constraints.minimumCount = value + constraints.maximumCount = value + } else if let rangeExpression = firstArgument.expression.as(SequenceExprSyntax.self) { + let (minimum, maximum) = parseClosedRangeInt(rangeExpression) + constraints.minimumCount = minimum + constraints.maximumCount = maximum + } + case "minimumCount": + if let intLiteral = firstArgument.expression.as(IntegerLiteralExprSyntax.self), + let value = Int(intLiteral.literal.text) { + constraints.minimumCount = value + } + case "maximumCount": + if let intLiteral = firstArgument.expression.as(IntegerLiteralExprSyntax.self), + let value = Int(intLiteral.literal.text) { + constraints.maximumCount = value + } + case "minimum": + constraints.minimum = parseNumericLiteral(firstArgument.expression) + case "maximum": + constraints.maximum = parseNumericLiteral(firstArgument.expression) + case "range": + if let rangeExpression = firstArgument.expression.as(SequenceExprSyntax.self) { + let (minimum, maximum) = parseClosedRangeDouble(rangeExpression) + constraints.minimum = minimum + constraints.maximum = maximum + } + default: + break + } + } + + private static func parsePatternFromExpression(_ expression: ExprSyntax) -> String? { + if let functionCall = expression.as(FunctionCallExprSyntax.self) { + let functionName: String? 
+ if let memberAccess = functionCall.calledExpression.as(MemberAccessExprSyntax.self) { + functionName = memberAccess.declName.baseName.text + } else if let identifier = functionCall.calledExpression.as(DeclReferenceExprSyntax.self) { + functionName = identifier.baseName.text + } else { + functionName = nil + } + + if functionName == "pattern", + let firstArg = functionCall.arguments.first, + let stringLiteral = firstArg.expression.as(StringLiteralExprSyntax.self) { + return stringLiteral.segments.description.trimmingCharacters(in: .init(charactersIn: "\"")) + } + } + return nil } private static func isDictionaryType(_ type: String) -> Bool { @@ -173,6 +233,78 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { return trimmed.hasPrefix("[") && trimmed.contains(":") && trimmed.hasSuffix("]") } + private static func escapeDescriptionString(_ description: String?) -> String { + guard let description else { return "nil" } + return makeSwiftStringLiteralExpression(description) + } + + /// Escapes text so it can be embedded safely inside generated Swift source as a string literal. + /// + /// Multi-line strings need newlines converted to `\n` escape sequences, and special characters + /// (backslashes and quotes) must be escaped. + private static func makeSwiftStringLiteralExpression(_ value: String) -> String { + let escaped = value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + return "\"\(escaped)\"" + } + + private static func buildGuidesArray(for property: PropertyInfo) -> String { + let baseType = property.type.replacingOccurrences(of: "?", with: "") + + if baseType.hasPrefix("[") && baseType.hasSuffix("]") && !isDictionaryType(baseType) { + if property.guide.constraints.minimumCount != nil || property.guide.constraints.maximumCount != nil { + let minStr = property.guide.constraints.minimumCount.map { String($0) } ?? 
"nil" + let maxStr = property.guide.constraints.maximumCount.map { String($0) } ?? "nil" + return "[GenerationGuide(minimumCount: \(minStr), maximumCount: \(maxStr))]" + } + return "[]" + } + + if baseType == "Int" || baseType == "Double" || baseType == "Float" { + if property.guide.constraints.minimum != nil || property.guide.constraints.maximum != nil { + let minStr = property.guide.constraints.minimum.map { String($0) } ?? "nil" + let maxStr = property.guide.constraints.maximum.map { String($0) } ?? "nil" + return "[GenerationGuide(minimum: \(minStr), maximum: \(maxStr))]" + } + return "[]" + } + + return "[]" + } + + private static func parseNumericLiteral(_ expression: ExprSyntax) -> Double? { + if let intLiteral = expression.as(IntegerLiteralExprSyntax.self) { + return Double(intLiteral.literal.text) + } else if let floatLiteral = expression.as(FloatLiteralExprSyntax.self) { + return Double(floatLiteral.literal.text) + } else if let prefixExpression = expression.as(PrefixOperatorExprSyntax.self), + prefixExpression.operator.text == "-" { + if let value = parseNumericLiteral(prefixExpression.expression) { + return -value + } + } + return nil + } + + private static func parseClosedRangeInt(_ expression: SequenceExprSyntax) -> (Int?, Int?) { + let elements = Array(expression.elements) + guard elements.count == 3, + let lowerBound = elements[0].as(IntegerLiteralExprSyntax.self), + let upperBound = elements[2].as(IntegerLiteralExprSyntax.self) + else { return (nil, nil) } + return (Int(lowerBound.literal.text), Int(upperBound.literal.text)) + } + + private static func parseClosedRangeDouble(_ expression: SequenceExprSyntax) -> (Double?, Double?) { + let elements = Array(expression.elements) + guard elements.count == 3 else { return (nil, nil) } + let minimum = parseNumericLiteral(elements[0]) + let maximum = parseNumericLiteral(elements[2]) + return (minimum, maximum) + } + private static func extractDictionaryTypes(_ type: String) -> (key: String, value: String)? 
{ let trimmed = type.trimmingCharacters(in: .whitespacesAndNewlines) @@ -601,32 +733,8 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { properties: [PropertyInfo] ) -> DeclSyntax { let propertySchemas = properties.map { prop in - var guidesArray = "[]" - if !prop.guides.isEmpty || prop.pattern != nil { - var guides: [String] = [] - - if let pattern = prop.pattern { - guides.append(".pattern(\"\(pattern)\")") - } - - guides.append(contentsOf: prop.guides) - guidesArray = "[\(guides.joined(separator: ", "))]" - } - - // Escape the description string so it can be safely embedded in generated code. - // Multi-line strings need newlines converted to \n escape sequences, - // and special characters (backslashes, quotes) must be escaped. - let escapedDescription: String - if let desc = prop.guideDescription { - let escaped = - desc - .replacingOccurrences(of: "\\", with: "\\\\") // Escape backslashes first - .replacingOccurrences(of: "\"", with: "\\\"") // Escape quotes - .replacingOccurrences(of: "\n", with: "\\n") // Convert newlines to escape sequences - escapedDescription = "\"\(escaped)\"" - } else { - escapedDescription = "nil" - } + let escapedDescription = escapeDescriptionString(prop.guide.description) + let guidesArray = buildGuidesArray(for: prop) return """ GenerationSchema.Property( @@ -1204,14 +1312,19 @@ private struct EnumCaseInfo { private struct GuideInfo { let description: String? - let guides: [String] - let pattern: String? + let constraints: Constraints +} + +private struct Constraints { + var minimumCount: Int? + var maximumCount: Int? + var minimum: Double? + var maximum: Double? + var pattern: String? } private struct PropertyInfo { let name: String let type: String - let guideDescription: String? - let guides: [String] - let pattern: String? 
+ let guide: GuideInfo } diff --git a/Tests/AnyLanguageModelTests/GenerableMacroTests.swift b/Tests/AnyLanguageModelTests/GenerableMacroTests.swift index 1e6c12e7..5c6c242b 100644 --- a/Tests/AnyLanguageModelTests/GenerableMacroTests.swift +++ b/Tests/AnyLanguageModelTests/GenerableMacroTests.swift @@ -34,6 +34,38 @@ struct TestArguments { var age: Int } +@Generable +private enum TestEnum: Equatable { + case optionA + case optionB + case optionC +} + +@Generable +private struct TestNestedInner: Equatable { + var value: String + var count: Int +} + +@Generable +private struct TestNestedOuter: Equatable { + var name: String + var inner: TestNestedInner +} + +@Generable +private struct TestStructWithEnum: Equatable { + var label: String + var choice: TestEnum +} + +@Generable +private struct TestStructWithArray: Equatable { + var title: String + @Guide(.count(3)) + var items: [String] +} + @Suite("Generable Macro") struct GenerableMacroTests { @Test("@Guide description with multiline string") @@ -135,4 +167,36 @@ struct GenerableMacroTests { #expect(args.name == "Bob") #expect(args.age == 25) } + + @Test("Enum round-trip conversion") + func enumRoundTrip() throws { + for choice in [TestEnum.optionA, TestEnum.optionB, TestEnum.optionC] { + let restored = try TestEnum(choice.generatedContent) + #expect(choice == restored) + } + } + + @Test("Nested struct round-trip conversion") + func nestedStructRoundTrip() throws { + let original = TestNestedOuter( + name: "outer", + inner: TestNestedInner(value: "inner", count: 42) + ) + let restored = try TestNestedOuter(original.generatedContent) + #expect(original == restored) + } + + @Test("Struct with enum round-trip conversion") + func structWithEnumRoundTrip() throws { + let original = TestStructWithEnum(label: "test", choice: .optionB) + let restored = try TestStructWithEnum(original.generatedContent) + #expect(original == restored) + } + + @Test("Struct with array round-trip conversion") + func structWithArrayRoundTrip() 
throws { + let original = TestStructWithArray(title: "list", items: ["a", "b", "c"]) + let restored = try TestStructWithArray(original.generatedContent) + #expect(original == restored) + } } diff --git a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift new file mode 100644 index 00000000..787e2930 --- /dev/null +++ b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift @@ -0,0 +1,442 @@ +import Foundation +import Testing + +@testable import AnyLanguageModel + +@Generable +enum Priority: Equatable { + case low + case medium + case high +} + +@Generable +struct SimpleString: Equatable { + @Guide(description: "A greeting message") + var message: String +} + +@Generable +struct SimpleInt: Equatable { + @Guide(description: "A count value", .minimum(0)) + var count: Int +} + +@Generable +struct SimpleBool: Equatable { + @Guide(description: "A boolean flag") + var value: Bool +} + +@Generable +struct SimpleDouble: Equatable { + @Guide(description: "A temperature value") + var temperature: Double +} + +@Generable +struct OptionalFields: Equatable { + @Guide(description: "A required name") + var name: String + + @Guide(description: "An optional nickname") + var nickname: String? 
}

/// Flat fixture mixing string, integer, boolean and floating-point fields.
/// `age` carries a `.minimum(0)` guide.
@Generable
struct BasicStruct: Equatable {
    @Guide(description: "Person's name")
    var name: String

    @Guide(description: "Person's age", .minimum(0))
    var age: Int

    @Guide(description: "Is the person active")
    var isActive: Bool

    @Guide(description: "Score value")
    var score: Double
}

/// Fixture nested inside `Person`.
@Generable
struct Address: Equatable {
    @Guide(description: "Street name")
    var street: String

    @Guide(description: "City name")
    var city: String

    @Guide(description: "Postal code")
    var postalCode: String
}

/// Fixture referenced twice by `ContainerWithDuplicateNestedType`.
@Generable
struct ReusedNestedStruct: Equatable {
    @Guide(description: "Some text")
    var text: String
}

/// Fixture with two properties of the same nested `@Generable` type.
@Generable
struct ContainerWithDuplicateNestedType: Equatable {
    var first: ReusedNestedStruct
    var second: ReusedNestedStruct
}

/// Fixture with a nested `Address`.
@Generable
struct Person: Equatable {
    @Guide(description: "Person's name")
    var name: String

    @Guide(description: "Person's age")
    var age: Int

    var address: Address
}

/// Fixture combining scalar fields with a `Priority` value.
@Generable
struct TaskItem: Equatable {
    @Guide(description: "Task title")
    var title: String

    var priority: Priority

    @Guide(description: "Is completed")
    var isCompleted: Bool
}

/// Fixture with a single array-of-strings field.
@Generable
struct SimpleArray: Equatable {
    @Guide(description: "A list of color names")
    var colors: [String]
}

/// Fixture whose `choices` array is guided with `.count(4)`.
@Generable
struct MultiChoiceQuestion: Equatable {
    @Guide(description: "The quiz question")
    var text: String

    @Guide(.count(4))
    var choices: [String]

    var answer: String

    @Guide(description: "A brief explanation of why the answer is correct")
    var explanation: String
}

/// A language model paired with the label used in logs and failure messages.
private struct SupportedModel: Sendable {
    let name: String
    let model: any LanguageModel

    /// The models the structured-generation tests run against.
    ///
    /// Which models are included depends on compile-time conditions and on
    /// environment variables. Each variable is looked up under its plain name
    /// first and then under the same name prefixed with `TEST_RUNNER_`.
    static var all: [SupportedModel] {
        func environmentValue(_ key: String) -> String? {
            let environment = ProcessInfo.processInfo.environment
            return environment[key] ?? environment["TEST_RUNNER_\(key)"]
        }

        var models: [SupportedModel] = []

        #if canImport(FoundationModels)
            if #available(macOS 26.0, *) {
                if SystemLanguageModel.default.isAvailable {
                    models.append(
                        SupportedModel(name: "SystemLanguageModel", model: SystemLanguageModel.default)
                    )
                }
            }
        #endif

        #if Llama
            // Requires LLAMA_MODEL_PATH to be set.
            if let modelPath = environmentValue("LLAMA_MODEL_PATH") {
                models.append(
                    SupportedModel(
                        name: "LlamaLanguageModel",
                        model: LlamaLanguageModel(modelPath: modelPath)
                    )
                )
            }
        #endif

        #if MLX
            // Enabled explicitly through ENABLE_MLX_TESTS, or implicitly when CI is
            // unset while both HF_TOKEN and XCTestConfigurationFilePath are set.
            let shouldRunMLX =
                environmentValue("ENABLE_MLX_TESTS") != nil
                || (environmentValue("CI") == nil
                    && environmentValue("HF_TOKEN") != nil
                    && environmentValue("XCTestConfigurationFilePath") != nil)
            if let modelId = environmentValue("MLX_MODEL_ID"), shouldRunMLX {
                models.append(
                    SupportedModel(
                        name: "MLXLanguageModel",
                        model: MLXLanguageModel(modelId: modelId)
                    )
                )
            }
        #endif

        return models
    }
}

/// The result of `SupportedModel.all`, shared by the suite trait and every test.
private let supportedModels = SupportedModel.all

/// Returns `true` when at least one model is available to test against.
private func isGenerationTestsEnabled() -> Bool {
    !supportedModels.isEmpty
}

@Test("GenerationSchema merges duplicate defs for the same type")
func generationSchemaMergesDuplicateDefsForSameType() {
    let schema = ContainerWithDuplicateNestedType.generationSchema

    // The nested type is looked up in `defs` by its fully qualified name.
    let nestedTypeName = String(reflecting: ReusedNestedStruct.self)
    #expect(schema.defs[nestedTypeName] != nil)
}

/// Runs `test` once for every entry in `supportedModels`, in order.
///
/// A thrown error does not stop the loop: errors are collected and recorded as
/// issues, tagged with the model name, after every model has run.
///
/// - Parameter test: The body to execute for each supported model.
private func testAllModels(_ test: (SupportedModel) async throws -> Void) async {
    var failures: [(name: String, error: any Error)] = []

    for model in supportedModels {
        do {
            try await test(model)
        } catch {
            failures.append((model.name, error))
        }
    }

    for failure in failures {
        Issue.record("[\(failure.name)] \(failure.error)")
    }
}

/// Prints the JSON of `content` under a `[model]` header.
///
/// The JSON is pretty-printed with sorted keys when it can be parsed; otherwise
/// the raw JSON string is printed unchanged.
///
/// - Parameters:
///   - content: The generated value to log. `T` is constrained to `Generable`
///     because the body reads `generatedContent`.
///   - model: The model label printed in the header.
private func logGenerated<T: Generable>(_ content: T, model: String) {
    let json = content.generatedContent.jsonString
    if let data = json.data(using: .utf8),
        let object = try? JSONSerialization.jsonObject(with: data),
        let prettyData = try? JSONSerialization.data(
            withJSONObject: object,
            options: [.prettyPrinted, .sortedKeys]
        ),
        let prettyJSON = String(data: prettyData, encoding: .utf8)
    {
        print("\n[\(model)]\n\(prettyJSON)\n")
    } else {
        print("\n[\(model)]\n\(json)\n")
    }
}

/// Creates the session every structured-generation test starts from.
///
/// - Parameter model: The supported model to bind the session to.
/// - Returns: A new session configured with the shared instructions.
private func makeSession(for model: SupportedModel) -> LanguageModelSession {
    LanguageModelSession(
        model: model.model,
        instructions: "You are a helpful assistant that generates structured data."
    )
}

/// End-to-end structured-generation checks, run serially against every
/// supported model. The suite is disabled when no model is available.
@Suite("Structured Generation", .serialized, .enabled(if: isGenerationTestsEnabled()))
struct StructuredGenerationTests {
    @Test("Generate SimpleString with all supported models")
    func generateSimpleString() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a greeting message that says hello",
                generating: SimpleString.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.message.isEmpty, "[\(model.name)] message should not be empty")
        }
    }

    @Test("Generate SimpleInt with all supported models")
    func generateSimpleInt() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a count value of 42",
                generating: SimpleInt.self
            )

            logGenerated(response.content, model: model.name)
            #expect(response.content.count >= 0, "[\(model.name)] count should be non-negative")
        }
    }

    @Test("Generate SimpleDouble with all supported models")
    func generateSimpleDouble() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a temperature value of 72.5 degrees",
                generating: SimpleDouble.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.temperature.isNaN, "[\(model.name)] temperature should be a valid number")
        }
    }

    @Test("Generate SimpleBool with all supported models")
    func generateSimpleBool() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a boolean value: true",
                generating: SimpleBool.self
            )

            logGenerated(response.content, model: model.name)

            // Inspect the raw JSON so a string such as "true" is not accepted
            // in place of a real JSON boolean.
            let jsonData = response.rawContent.jsonString.data(using: .utf8)
            #expect(jsonData != nil, "[\(model.name)] rawContent should be valid UTF-8 JSON")
            if let jsonData {
                let json = try JSONSerialization.jsonObject(with: jsonData)
                let dictionary = json as? [String: Any]
                let boolValue = dictionary?["value"] as? Bool
                #expect(boolValue != nil, "[\(model.name)] value should be encoded as a JSON boolean")
            }
        }
    }

    @Test("Generate OptionalFields with all supported models")
    func generateOptionalFields() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a person named Alex with nickname 'Lex'. Nickname may be omitted if unsure.",
                generating: OptionalFields.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty")
            // The nickname is only validated when the model chose to produce one.
            if let nickname = response.content.nickname {
                #expect(!nickname.isEmpty, "[\(model.name)] nickname should not be empty when present")
            }
        }
    }

    @Test("Generate Priority enum with all supported models")
    func generatePriorityEnum() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a high priority value",
                generating: Priority.self
            )

            logGenerated(response.content, model: model.name)
            #expect(
                [Priority.low, Priority.medium, Priority.high].contains(response.content),
                "[\(model.name)] should generate valid priority"
            )
        }
    }

    @Test("Generate BasicStruct with all supported models")
    func generateBasicStruct() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a person with name Alice, age 30, active status true, and score 95.5",
                generating: BasicStruct.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty")
            #expect(response.content.age >= 0, "[\(model.name)] age should be non-negative")
        }
    }

    @Test("Generate nested struct (Person with Address) with all supported models")
    func generateNestedStruct() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a person named John, age 25, living at 123 Main St, Springfield, 12345",
                generating: Person.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty")
            #expect(response.content.age >= 0, "[\(model.name)] age should be non-negative")
            #expect(!response.content.address.street.isEmpty, "[\(model.name)] street should not be empty")
            #expect(!response.content.address.city.isEmpty, "[\(model.name)] city should not be empty")
        }
    }

    @Test("Generate struct with enum (TaskItem) with all supported models")
    func generateStructWithEnum() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a task titled 'Complete project' with high priority, not completed",
                generating: TaskItem.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.title.isEmpty, "[\(model.name)] title should not be empty")
            #expect(
                [Priority.low, Priority.medium, Priority.high].contains(response.content.priority),
                "[\(model.name)] should have valid priority"
            )
        }
    }

    @Test("Generate simple array with all supported models")
    func generateSimpleArray() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: "Generate a list of 3 color names: red, green, blue",
                generating: SimpleArray.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.colors.isEmpty, "[\(model.name)] colors should not be empty")
        }
    }

    @Test("Generate struct with array (MultiChoiceQuestion) with all supported models")
    func generateStructWithArray() async {
        await testAllModels { model in
            let session = makeSession(for: model)

            let response = try await session.respond(
                to: """
                    Generate a quiz question:
                    - Question: What is the capital of France?
                    - Choices: London, Paris, Berlin, Madrid
                    - Answer: Paris
                    - Explanation: Paris is the capital city of France
                    """,
                generating: MultiChoiceQuestion.self
            )

            logGenerated(response.content, model: model.name)
            #expect(!response.content.text.isEmpty, "[\(model.name)] question text should not be empty")
            #expect(response.content.choices.count == 4, "[\(model.name)] should have exactly 4 choices")
            #expect(!response.content.answer.isEmpty, "[\(model.name)] answer should not be empty")
        }
    }
}