-
Notifications
You must be signed in to change notification settings - Fork 43
Implement structured output generation for both LlamaLanguageModel / MLXLanguageModel #75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements structured output generation for LlamaLanguageModel and MLXLanguageModel by adding constrained token sampling to generate JSON that conforms to a schema. The implementation includes comprehensive tests covering various data types and structures.
Key changes:
- Added
ConstrainedJSONGeneratorthat uses token-level sampling to generate schema-conformant JSON - Implemented
TokenBackendprotocol with adapters for both Llama and MLX models - Enhanced
GenerationGuideto store constraint values for min/max on numbers and arrays - Extended
GenerationSchemawith character validation and schema prompt generation
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| Tests/AnyLanguageModelTests/StructuredGenerationTests.swift | Comprehensive test suite covering simple types, nested structs, enums, arrays, and optionals across all supported model types |
| Tests/AnyLanguageModelTests/GenerableMacroTests.swift | Added round-trip tests for enums, nested structs, and arrays |
| Sources/AnyLanguageModelMacros/GenerableMacro.swift | Refactored guide extraction to use a structured Constraints type and properly parse numeric ranges and array count constraints |
| Sources/AnyLanguageModel/StructuredGeneration.swift | New file implementing token-level constrained JSON generation with TokenBackend protocol and ConstrainedJSONGenerator |
| Sources/AnyLanguageModel/Models/SystemLanguageModel.swift | Updated to use schema-based generation for non-String types and added conversion to FoundationModels.DynamicGenerationSchema |
| Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | Implemented MLXTokenBackend and structured JSON generation with proper token sampling and repetition penalty handling |
| Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift | Implemented LlamaTokenBackend and structured JSON generation with batch-based decoding and sampler integration |
| Sources/AnyLanguageModel/GenerationSchema.swift | Added schemaPrompt() method, character validation for JSON strings, improved node equality checking, and support for constraint propagation |
| Sources/AnyLanguageModel/GenerationGuide.swift | Made GenerationGuide store actual constraint values (min/max, minCount/maxCount) for use during schema generation |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| private static func buildValidIntegerTokens(backend: Backend) -> Set<Int> { | ||
| var allowed = Set<Int>() | ||
| for token in 0 ..< backend.vocabSize { |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The buildValidIntegerTokens and buildValidDecimalTokens methods don't reserve capacity for the Set, unlike buildValidStringTokens which does. For large vocabularies (tens of thousands of tokens), this could cause multiple reallocations during Set growth. Consider adding allowed.reserveCapacity(backend.vocabSize / 10) or similar to improve performance.
| } | ||
|
|
||
| private mutating func generateArray(_ node: GenerationSchema.ArrayNode) throws -> String { | ||
| let count = node.minItems ?? node.maxItems ?? 4 |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The array count selection logic prioritizes minItems over maxItems when both are present. If minItems is set but is larger than maxItems, this will generate an array that violates the maxItems constraint. Consider validating that minItems <= maxItems, or choosing a count within the range (e.g., minItems when both are present, or a random value between them).
| guard self != "\\" else { return false } | ||
| guard let scalar = unicodeScalars.first, scalar.value >= 0x20 else { return false } |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JSON character validation excludes backslashes but doesn't handle other control characters that need escaping in JSON (like tabs, newlines, carriage returns when they appear literally in strings). JSON strings must escape control characters (0x00-0x1F) except for the allowed whitespace. Consider either allowing these characters to be properly escaped, or ensuring the validation is consistent with how tokens will actually be used in JSON string generation.
| if type != String.self { | ||
| let jsonString = try await generateStructuredJSON( | ||
| context: context, | ||
| session: session, | ||
| prompt: prompt, | ||
| schema: type.generationSchema, | ||
| options: options | ||
| ) | ||
| let generatedContent = try GeneratedContent(json: jsonString) | ||
| let content = try type.init(generatedContent) | ||
| return LanguageModelSession.Response( | ||
| content: content, | ||
| rawContent: generatedContent, | ||
| transcriptEntries: ArraySlice([]) | ||
| ) | ||
| } |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The indentation is inconsistent here. This code block should be indented to match the surrounding method body. The if statement starting at line 70 should be aligned with the lines above it (lines 61-68).
| if type != String.self { | |
| let jsonString = try await generateStructuredJSON( | |
| context: context, | |
| session: session, | |
| prompt: prompt, | |
| schema: type.generationSchema, | |
| options: options | |
| ) | |
| let generatedContent = try GeneratedContent(json: jsonString) | |
| let content = try type.init(generatedContent) | |
| return LanguageModelSession.Response( | |
| content: content, | |
| rawContent: generatedContent, | |
| transcriptEntries: ArraySlice([]) | |
| ) | |
| } | |
| if type != String.self { | |
| let jsonString = try await generateStructuredJSON( | |
| context: context, | |
| session: session, | |
| prompt: prompt, | |
| schema: type.generationSchema, | |
| options: options | |
| ) | |
| let generatedContent = try GeneratedContent(json: jsonString) | |
| let content = try type.init(generatedContent) | |
| return LanguageModelSession.Response( | |
| content: content, | |
| rawContent: generatedContent, | |
| transcriptEntries: ArraySlice([]) | |
| ) | |
| } |
| private func clampNumberString(_ text: String, node: GenerationSchema.NumberNode) -> String { | ||
| if node.integerOnly { | ||
| let value = Int(text) ?? 0 | ||
| let clamped = clampInt(value, min: node.minimum, max: node.maximum) | ||
| return String(clamped) | ||
| } else { | ||
| let value = Double(text) ?? 0 | ||
| let clamped = clampDouble(value, min: node.minimum, max: node.maximum) | ||
| return formatDouble(clamped) |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number parsing logic uses Int(text) ?? 0 and Double(text) ?? 0, which means that any unparseable string (like "---", "...", or an empty string after token sampling) defaults to 0. This could mask generation errors. Consider validating the generated text is a valid number format before clamping, or throwing an error if the parsing fails.
| private func clampNumberString(_ text: String, node: GenerationSchema.NumberNode) -> String { | |
| if node.integerOnly { | |
| let value = Int(text) ?? 0 | |
| let clamped = clampInt(value, min: node.minimum, max: node.maximum) | |
| return String(clamped) | |
| } else { | |
| let value = Double(text) ?? 0 | |
| let clamped = clampDouble(value, min: node.minimum, max: node.maximum) | |
| return formatDouble(clamped) | |
| private func normalizedNumericText(_ text: String) -> String { | |
| text.trimmingCharacters(in: .whitespacesAndNewlines) | |
| } | |
| private func parseInteger(_ text: String) -> Int? { | |
| let normalized = normalizedNumericText(text) | |
| guard !normalized.isEmpty else { return nil } | |
| return Int(normalized) | |
| } | |
| private func parseDouble(_ text: String) -> Double? { | |
| let normalized = normalizedNumericText(text) | |
| guard !normalized.isEmpty else { return nil } | |
| return Double(normalized) | |
| } | |
| private func fallbackInteger(for node: GenerationSchema.NumberNode) -> Int { | |
| if let minimum = node.minimum { | |
| return Int(ceil(minimum)) | |
| } else if let maximum = node.maximum { | |
| return Int(floor(maximum)) | |
| } else { | |
| return 0 | |
| } | |
| } | |
| private func fallbackDouble(for node: GenerationSchema.NumberNode) -> Double { | |
| if let minimum = node.minimum { | |
| return minimum | |
| } else if let maximum = node.maximum { | |
| return maximum | |
| } else { | |
| return 0 | |
| } | |
| } | |
| private func clampNumberString(_ text: String, node: GenerationSchema.NumberNode) -> String { | |
| if node.integerOnly { | |
| if let value = parseInteger(text) { | |
| let clamped = clampInt(value, min: node.minimum, max: node.maximum) | |
| return String(clamped) | |
| } else { | |
| let fallback = fallbackInteger(for: node) | |
| let clamped = clampInt(fallback, min: node.minimum, max: node.maximum) | |
| return String(clamped) | |
| } | |
| } else { | |
| if let value = parseDouble(text) { | |
| let clamped = clampDouble(value, min: node.minimum, max: node.maximum) | |
| return formatDouble(clamped) | |
| } else { | |
| let fallback = fallbackDouble(for: node) | |
| let clamped = clampDouble(fallback, min: node.minimum, max: node.maximum) | |
| return formatDouble(clamped) | |
| } |
| } | ||
|
|
||
| private static func buildValidDecimalTokens(backend: Backend) -> Set<Int> { | ||
| var allowed = Set<Int>() |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The buildValidDecimalTokens method doesn't reserve capacity for the Set, unlike buildValidStringTokens which does. For large vocabularies (tens of thousands of tokens), this could cause multiple reallocations during Set growth. Consider adding allowed.reserveCapacity(backend.vocabSize / 10) or similar to improve performance.
| var allowed = Set<Int>() | |
| var allowed = Set<Int>() | |
| allowed.reserveCapacity(backend.vocabSize / 10) |
| for tokenIndex in 0 ..< vocabSize { | ||
| if !allowedTokens.contains(tokenIndex) { | ||
| logits[tokenIndex] = -Float.infinity | ||
| } | ||
| } |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sample method modifies the logits array in place by setting disallowed tokens to negative infinity. Since llama_get_logits returns a pointer to internal model state, this modification could have unintended side effects if the logits buffer is reused. Consider creating a copy of the logits before modification, or ensure this is the intended behavior with llama.cpp.
|
|
||
| while backend.remainingTokens > 0, result.count < maxTokens { | ||
| let token = try backend.sample(from: allowedTokens) | ||
| if basicTerminators.contains(token) { break } | ||
|
|
||
| guard let text = backend.tokenText(token) else { break } | ||
| result += text |
Copilot
AI
Jan 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number generation loop uses result.count to limit the number of characters rather than tracking token count. This means that a single token could exceed the maxTokens limit if it contains many characters. Consider tracking the number of tokens generated instead, similar to the generateFreeString method which uses a generated counter.
| while backend.remainingTokens > 0, result.count < maxTokens { | |
| let token = try backend.sample(from: allowedTokens) | |
| if basicTerminators.contains(token) { break } | |
| guard let text = backend.tokenText(token) else { break } | |
| result += text | |
| var generated = 0 | |
| while backend.remainingTokens > 0, generated < maxTokens { | |
| let token = try backend.sample(from: allowedTokens) | |
| if basicTerminators.contains(token) { break } | |
| guard let text = backend.tokenText(token) else { break } | |
| result += text | |
| generated += 1 |
|
@eastriverlee Thank you for your contribution! And thank you for your patience. Reviewing this now that I'm back from the holiday break. |
Related to #27