@@ -11,6 +11,7 @@ import Foundation
1111#endif
1212
1313#if MLX
14+ import JSONSchema
1415 import MLXLMCommon
1516 import MLX
1617 import MLXVLM
@@ -205,7 +206,8 @@ import Foundation
205206 session: session,
206207 prompt: prompt,
207208 schema: type. generationSchema,
208- options: options
209+ options: options,
210+ includeSchemaInPrompt: includeSchemaInPrompt
209211 )
210212 let generatedContent = try GeneratedContent ( json: jsonString)
211213 let content = try type. init ( generatedContent)
@@ -662,6 +664,37 @@ import Foundation
662664 return textParts. joined ( separator: " \n " )
663665 }
664666
/// Builds the instruction text that asks the model to reply with JSON for `schema`.
///
/// The schema is encoded to pretty-printed JSON with sorted keys, and the same bytes are
/// decoded as a `JSONSchema` so that its type name, description, `const`, and `enum`
/// members can be read for the header line. If encoding, decoding, or UTF-8 conversion
/// fails, the result of `schema.schemaPrompt()` is returned instead.
///
/// - Parameter schema: The generation schema to describe in the prompt.
/// - Returns: A header line followed by the encoded schema JSON, or the fallback prompt.
667+ private func schemaPrompt( for schema: GenerationSchema ) -> String {
668+ let encoder = JSONEncoder ( )
// Sorted keys make the encoded schema text independent of key iteration order.
669+ encoder. outputFormatting = [ . prettyPrinted, . sortedKeys]
// All three steps must succeed; `try?` turns any thrown error into the fallback path.
670+ guard
671+ let data = try ? encoder. encode ( schema) ,
672+ let jsonSchema = try ? JSONDecoder ( ) . decode ( JSONSchema . self, from: data) ,
673+ let schemaJSON = String ( data: data, encoding: . utf8)
674+ else {
// Fallback: use the prompt produced by the schema itself.
675+ return schema. schemaPrompt ( )
676+ }
677+
// The header names the schema's type and, when non-empty, appends its description.
678+ var header = " Respond with valid JSON matching this \( jsonSchema. typeName) schema "
679+ if let description = jsonSchema. description, !description. isEmpty {
680+ header += " ( \( description) ) "
681+ }
682+
// `const` is checked before `enum`; the enum branch is only reached when there is no
// const value or the const value could not be encoded to a UTF-8 string.
// NOTE(review): both branches reuse `encoder`, which has `.prettyPrinted` set above, so
// these inline values inherit that formatting — confirm this is the intended output.
683+ if let constValue = jsonSchema. const,
684+ let data = try ? encoder. encode ( constValue) ,
685+ let constString = String ( data: data, encoding: . utf8)
686+ {
687+ header += " . Expected value: \( constString) "
688+ } else if let enumValues = jsonSchema. enum, !enumValues. isEmpty,
689+ let data = try ? encoder. encode ( JSONValue . array ( enumValues) ) ,
690+ let enumString = String ( data: data, encoding: . utf8)
691+ {
692+ header += " . Allowed values: \( enumString) "
693+ }
694+
// Final prompt: the header, then the encoded schema JSON after a line break.
695+ return " \( header) : \n \( schemaJSON) "
696+ }
697+
665698 // MARK: - Structured JSON Generation
666699
667700 private enum StructuredGenerationError : Error {
@@ -674,13 +707,15 @@ import Foundation
674707 session: LanguageModelSession ,
675708 prompt: Prompt ,
676709 schema: GenerationSchema ,
677- options: GenerationOptions
710+ options: GenerationOptions ,
711+ includeSchemaInPrompt: Bool
678712 ) async throws -> String {
679713 let maxTokens = options. maximumResponseTokens ?? 512
680714 let generateParameters = toStructuredGenerateParameters ( options)
681715
682716 let baseChat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
683- let chat = normalizeChatForStructuredGeneration ( baseChat, schemaPrompt: schema. schemaPrompt ( ) )
717+ let schemaPrompt = includeSchemaInPrompt ? schemaPrompt ( for: schema) : nil
718+ let chat = normalizeChatForStructuredGeneration ( baseChat, schemaPrompt: schemaPrompt)
684719 let userInput = MLXLMCommon . UserInput (
685720 chat: chat,
686721 processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
@@ -703,8 +738,12 @@ import Foundation
703738
704739 private func normalizeChatForStructuredGeneration(
705740 _ chat: [ MLXLMCommon . Chat . Message ] ,
706- schemaPrompt: String
741+ schemaPrompt: String ?
707742 ) -> [ MLXLMCommon . Chat . Message ] {
743+ guard let schemaPrompt, !schemaPrompt. isEmpty else {
744+ return chat
745+ }
746+
708747 var systemMessageParts : [ String ] = [ ]
709748 systemMessageParts. append ( schemaPrompt)
710749
0 commit comments