Skip to content
51 changes: 33 additions & 18 deletions Sources/AnyLanguageModel/GenerationGuide.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@ import struct Foundation.Decimal
import class Foundation.NSDecimalNumber

/// Guides that control how values are generated.
public struct GenerationGuide<Value> {}
public struct GenerationGuide<Value>: Sendable {
package var minimumCount: Int?
package var maximumCount: Int?
package var minimum: Double?
package var maximum: Double?

public init() {}

package init(minimumCount: Int?, maximumCount: Int?) {
self.minimumCount = minimumCount
self.maximumCount = maximumCount
}

package init(minimum: Double?, maximum: Double?) {
self.minimum = minimum
self.maximum = maximum
}
}

// MARK: - String Guides

Expand Down Expand Up @@ -45,7 +62,7 @@ extension GenerationGuide where Value == Int {
/// }
/// ```
public static func minimum(_ value: Int) -> GenerationGuide<Int> {
GenerationGuide<Int>()
GenerationGuide<Int>(minimum: Double(value), maximum: nil)
}

/// Enforces a maximum value.
Expand All @@ -65,7 +82,7 @@ extension GenerationGuide where Value == Int {
/// }
/// ```
public static func maximum(_ value: Int) -> GenerationGuide<Int> {
GenerationGuide<Int>()
GenerationGuide<Int>(minimum: nil, maximum: Double(value))
}

/// Enforces values fall within a range.
Expand All @@ -85,7 +102,7 @@ extension GenerationGuide where Value == Int {
/// }
/// ```
public static func range(_ range: ClosedRange<Int>) -> GenerationGuide<Int> {
GenerationGuide<Int>()
GenerationGuide<Int>(minimum: Double(range.lowerBound), maximum: Double(range.upperBound))
}
}

Expand Down Expand Up @@ -144,18 +161,18 @@ extension GenerationGuide where Value == Double {
/// Enforces a minimum value.
/// The bounds are inclusive.
public static func minimum(_ value: Double) -> GenerationGuide<Double> {
GenerationGuide<Double>()
GenerationGuide<Double>(minimum: value, maximum: nil)
}

/// Enforces a maximum value.
/// The bounds are inclusive.
public static func maximum(_ value: Double) -> GenerationGuide<Double> {
GenerationGuide<Double>()
GenerationGuide<Double>(minimum: nil, maximum: value)
}

/// Enforces values fall within a range.
public static func range(_ range: ClosedRange<Double>) -> GenerationGuide<Double> {
GenerationGuide<Double>()
GenerationGuide<Double>(minimum: range.lowerBound, maximum: range.upperBound)
}
}

Expand All @@ -168,33 +185,31 @@ extension GenerationGuide {
/// The bounds are inclusive.
public static func minimumCount<Element>(_ count: Int) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: count, maximumCount: nil)
}

/// Enforces a maximum number of elements in the array.
///
/// The bounds are inclusive.
public static func maximumCount<Element>(_ count: Int) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: nil, maximumCount: count)
}

/// Enforces that the number of elements in the array fall within a closed range.
public static func count<Element>(_ range: ClosedRange<Int>) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: range.lowerBound, maximumCount: range.upperBound)
}

/// Enforces that the array has exactly a certain number elements.
public static func count<Element>(_ count: Int) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: count, maximumCount: count)
}

/// Enforces a guide on the elements within the array.
public static func element<Element>(_ guide: GenerationGuide<Element>) -> GenerationGuide<
[Element]
>
public static func element<Element>(_ guide: GenerationGuide<Element>) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
}
Expand All @@ -210,7 +225,7 @@ extension GenerationGuide where Value == [Never] {
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.minimumCount(_:)` on your own.
public static func minimumCount(_ count: Int) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: count, maximumCount: nil)
}

/// Enforces a maximum number of elements in the array.
Expand All @@ -219,20 +234,20 @@ extension GenerationGuide where Value == [Never] {
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.maximumCount(_:)` on your own.
public static func maximumCount(_ count: Int) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: nil, maximumCount: count)
}

/// Enforces that the number of elements in the array fall within a closed range.
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own.
public static func count(_ range: ClosedRange<Int>) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: range.lowerBound, maximumCount: range.upperBound)
}

/// Enforces that the array has exactly a certain number elements.
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own.
public static func count(_ count: Int) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: count, maximumCount: count)
}
}
97 changes: 93 additions & 4 deletions Sources/AnyLanguageModel/GenerationSchema.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible
)
}
}

var nodeDescription: String? {
switch self {
case .object(let node): node.description
case .array(let node): node.description
case .string(let node): node.description
case .number(let node): node.description
case .boolean, .anyOf, .ref: nil
}
}
}

struct ObjectNode: Sendable, Codable {
Expand Down Expand Up @@ -204,7 +214,7 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible
}

let root: Node
private var defs: [String: Node]
var defs: [String: Node]

/// A string representation of the debug description.
///
Expand Down Expand Up @@ -504,12 +514,32 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible
}

private static func nodesEqual(_ a: Node, _ b: Node) -> Bool {
// Simple structural equality - could be enhanced
switch (a, b) {
case (.boolean, .boolean):
return true
case (.ref(let aName), .ref(let bName)):
return aName == bName
case (.string(let aString), .string(let bString)):
return aString.pattern == bString.pattern
&& aString.enumChoices == bString.enumChoices
case (.number(let aNumber), .number(let bNumber)):
return aNumber.integerOnly == bNumber.integerOnly
&& aNumber.minimum == bNumber.minimum
&& aNumber.maximum == bNumber.maximum
case (.array(let aArray), .array(let bArray)):
return aArray.minItems == bArray.minItems
&& aArray.maxItems == bArray.maxItems
&& nodesEqual(aArray.items, bArray.items)
case (.object(let aObject), .object(let bObject)):
return aObject.required == bObject.required
&& aObject.properties.keys == bObject.properties.keys
&& aObject.properties.allSatisfy { key, aNode in
guard let bNode = bObject.properties[key] else { return false }
return nodesEqual(aNode, bNode)
}
case (.anyOf(let aNodes), .anyOf(let bNodes)):
return aNodes.count == bNodes.count
&& zip(aNodes, bNodes).allSatisfy(nodesEqual)
default:
return false
}
Expand Down Expand Up @@ -693,16 +723,39 @@ extension GenerationSchema {
} else if type == String.self {
return (.string(StringNode(description: description, pattern: nil, enumChoices: nil)), [:])
} else if type == Int.self {
var minimum: Double?
var maximum: Double?
for guide in guides {
if let min = guide.minimum { minimum = min }
if let max = guide.maximum { maximum = max }
}
return (
.number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: true)), [:]
.number(NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: true)), [:]
)
} else if type == Float.self || type == Double.self || type == Decimal.self {
var minimum: Double?
var maximum: Double?
for guide in guides {
if let min = guide.minimum { minimum = min }
if let max = guide.maximum { maximum = max }
}
return (
.number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: false)), [:]
.number(NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: false)), [:]
)
} else {
// Complex type - use its schema
let schema = Value.generationSchema

// Arrays should be inlined, not referenced
if case .array(var arrayNode) = schema.root {
arrayNode.description = description
for guide in guides {
if let min = guide.minimumCount { arrayNode.minItems = min }
if let max = guide.maximumCount { arrayNode.maxItems = max }
}
return (.array(arrayNode), schema.defs)
}

let typeName = String(reflecting: Value.self)

var deps = schema.defs
Expand Down Expand Up @@ -800,4 +853,40 @@ extension GenerationSchema {
/// let data = try encoder.encode(schema)
/// ```
static let omitAdditionalPropertiesKey = CodingUserInfoKey(rawValue: "GenerationSchema.omitAdditionalProperties")!

package func schemaPrompt() -> String {
let encoder = JSONEncoder()
encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
guard let data = try? encoder.encode(self),
let schemaJSON = String(data: data, encoding: .utf8) else {
return "Respond with valid JSON only."
}
return "Respond with valid JSON matching this schema:\n\(schemaJSON)"
}
}

extension Character {
package static let jsonQuoteScalars: Set<UInt32> = [0x22, 0x201C, 0x201D, 0x2018, 0x2019]
package static let jsonAllowedWhitespaceCharacters: Set<Character> = [" ", "\t", "\n"]

package var containsEmojiScalar: Bool {
unicodeScalars.contains { scalar in
scalar.properties.isEmojiPresentation || scalar.properties.isEmoji
}
}

package var isValidJSONStringCharacter: Bool {
guard self != "\\" else { return false }
guard let scalar = unicodeScalars.first, scalar.value >= 0x20 else { return false }
Comment on lines +879 to +880
Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The JSON character validation excludes backslashes but doesn't handle other control characters that need escaping in JSON (like tabs, newlines, carriage returns when they appear literally in strings). JSON strings must escape control characters (0x00-0x1F) except for the allowed whitespace. Consider either allowing these characters to be properly escaped, or ensuring the validation is consistent with how tokens will actually be used in JSON string generation.

Copilot uses AI. Check for mistakes.
guard !Self.jsonQuoteScalars.contains(scalar.value) else { return false }

if let ascii = asciiValue {
let char = Character(UnicodeScalar(ascii))
if Self.jsonAllowedWhitespaceCharacters.contains(char) { return true }
return isLetter || isNumber || (isASCII && (isPunctuation || isSymbol))
}

// Allow non-ASCII letters/numbers and emoji, but disallow non-ASCII punctuation (e.g. "】")
return isLetter || isNumber || containsEmojiScalar
}
}
Loading