Skip to content

Commit 2488201

Browse files
committed
Respect schema prompt flag and enhance structured prompts with JSONSchema
1 parent 51294c6 commit 2488201

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import Foundation
1111
#endif
1212

1313
#if MLX
14+
import JSONSchema
1415
import MLXLMCommon
1516
import MLX
1617
import MLXVLM
@@ -205,7 +206,8 @@ import Foundation
205206
session: session,
206207
prompt: prompt,
207208
schema: type.generationSchema,
208-
options: options
209+
options: options,
210+
includeSchemaInPrompt: includeSchemaInPrompt
209211
)
210212
let generatedContent = try GeneratedContent(json: jsonString)
211213
let content = try type.init(generatedContent)
@@ -662,6 +664,37 @@ import Foundation
662664
return textParts.joined(separator: "\n")
663665
}
664666

667+
/// Builds the instruction text asking the model to answer with JSON that conforms to `schema`.
///
/// The schema is encoded to pretty-printed, key-sorted JSON and then re-decoded as a
/// `JSONSchema` so that its type name, description, and any `const`/`enum` constraint can be
/// called out in a one-line header placed ahead of the full schema text.
///
/// - Parameter schema: The generation schema the response is expected to satisfy.
/// - Returns: The header followed by a newline and the schema JSON. Falls back to
///   `schema.schemaPrompt()` when the schema cannot be encoded, cannot be decoded as a
///   `JSONSchema`, or its encoded bytes cannot be read as UTF-8.
private func schemaPrompt(for schema: GenerationSchema) -> String {
    // Multi-line, key-sorted output for the schema body appended after the header.
    let schemaEncoder = JSONEncoder()
    schemaEncoder.outputFormatting = [.prettyPrinted, .sortedKeys]

    // Compact output for values embedded in the header: pretty-printing an `enum` array here
    // would spread it over several lines and break up the header sentence.
    let inlineEncoder = JSONEncoder()
    inlineEncoder.outputFormatting = [.sortedKeys]

    /// Renders `value` as single-line JSON text, or `nil` when encoding or UTF-8 conversion fails.
    func inlineJSON<Value: Encodable>(_ value: Value) -> String? {
        guard let encoded = try? inlineEncoder.encode(value) else {
            return nil
        }
        return String(data: encoded, encoding: .utf8)
    }

    guard
        let schemaData = try? schemaEncoder.encode(schema),
        let jsonSchema = try? JSONDecoder().decode(JSONSchema.self, from: schemaData),
        let schemaJSON = String(data: schemaData, encoding: .utf8)
    else {
        return schema.schemaPrompt()
    }

    var header = "Respond with valid JSON matching this \(jsonSchema.typeName) schema"
    if let description = jsonSchema.description, !description.isEmpty {
        header += " (\(description))"
    }

    // A `const` constraint takes precedence; `enum` is only mentioned when there is no
    // `const` or the `const` value could not be rendered.
    if let constValue = jsonSchema.const,
        let constString = inlineJSON(constValue)
    {
        header += ". Expected value: \(constString)"
    } else if let enumValues = jsonSchema.enum, !enumValues.isEmpty,
        let enumString = inlineJSON(JSONValue.array(enumValues))
    {
        header += ". Allowed values: \(enumString)"
    }

    return "\(header):\n\(schemaJSON)"
}
697+
665698
// MARK: - Structured JSON Generation
666699

667700
private enum StructuredGenerationError: Error {
@@ -674,13 +707,15 @@ import Foundation
674707
session: LanguageModelSession,
675708
prompt: Prompt,
676709
schema: GenerationSchema,
677-
options: GenerationOptions
710+
options: GenerationOptions,
711+
includeSchemaInPrompt: Bool
678712
) async throws -> String {
679713
let maxTokens = options.maximumResponseTokens ?? 512
680714
let generateParameters = toStructuredGenerateParameters(options)
681715

682716
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
683-
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schema.schemaPrompt())
717+
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
718+
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)
684719
let userInput = MLXLMCommon.UserInput(
685720
chat: chat,
686721
processing: .init(resize: .init(width: 512, height: 512)),
@@ -703,8 +738,12 @@ import Foundation
703738

704739
private func normalizeChatForStructuredGeneration(
705740
_ chat: [MLXLMCommon.Chat.Message],
706-
schemaPrompt: String
741+
schemaPrompt: String?
707742
) -> [MLXLMCommon.Chat.Message] {
743+
guard let schemaPrompt, !schemaPrompt.isEmpty else {
744+
return chat
745+
}
746+
708747
var systemMessageParts: [String] = []
709748
systemMessageParts.append(schemaPrompt)
710749

0 commit comments

Comments
 (0)