@@ -460,10 +460,12 @@ import Foundation
         llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
 
         let maxTokens = runtimeOptions.maximumResponseTokens ?? 100
+        let fullPrompt = try formatPrompt(for: session)
+
         let text = try await generateText(
             context: context,
             model: model!,
-            prompt: prompt.description,
+            prompt: fullPrompt,
             maxTokens: maxTokens,
             options: runtimeOptions
         )
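
Context for this hunk: prompt.description sent only the latest user prompt as raw text, while fullPrompt renders the entire session transcript through the chat template embedded in the model's GGUF file (see the formatPrompt helper added further down). As a rough illustration, assuming a model whose GGUF carries a ChatML template (the markers vary by model family), the rendered prompt looks something like:

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    Hello!<|im_end|>
    <|im_start|>assistant

The trailing assistant header is what add_ass: true requests, cueing the model to generate the next assistant turn.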
@@ -518,12 +520,13 @@ import Foundation
         llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
 
         var accumulatedText = ""
+        let fullPrompt = try self.formatPrompt(for: session)
 
         do {
             for try await tokenText in generateTextStream(
                 context: context,
                 model: model!,
-                prompt: prompt.description,
+                prompt: fullPrompt,
                 maxTokens: maxTokens,
                 options: runtimeOptions
             ) {
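
The streaming path gets the same fix. A hypothetical caller-side sketch, assuming this class backs a FoundationModels-style LanguageModelSession (the initializer and streamResponse(to:) shape are assumptions, not confirmed by this diff):

    // Hypothetical usage; names other than LanguageModelSession are illustrative.
    let session = LanguageModelSession(model: llamaModel, instructions: "Answer briefly.")
    for try await partial in session.streamResponse(to: "What does a chat template do?") {
        print(partial) // accumulated text so far, fed by generateTextStream above
    }

Since formatPrompt(for:) replays the whole transcript on every call, multi-turn context is preserved without the caller re-sending earlier messages.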
@@ -942,6 +945,97 @@ import Foundation
 
     // MARK: - Helper Methods
 
+    private func formatPrompt(for session: LanguageModelSession) throws -> String {
+        guard let model = self.model else {
+            throw LlamaLanguageModelError.modelLoadFailed
+        }
+
+        var messages: [(role: String, content: String)] = []
+
+        for entry in session.transcript {
+            switch entry {
+            case .instructions(let instructions):
+                let text = extractText(from: instructions.segments)
+                if !text.isEmpty {
+                    messages.append(("system", text))
+                }
+
+            case .prompt(let prompt):
+                let text = extractText(from: prompt.segments)
+                if !text.isEmpty {
+                    messages.append(("user", text))
+                }
+
+            case .response(let response):
+                let text = extractText(from: response.segments)
+                if !text.isEmpty {
+                    messages.append(("assistant", text))
+                }
+
+            default:
+                break
+            }
+        }
+
+        // Keep C strings alive while using them
+        let cRoles = messages.map { strdup($0.role) }
+        let cContents = messages.map { strdup($0.content) }
+
+        defer {
+            cRoles.forEach { free($0) }
+            cContents.forEach { free($0) }
+        }
+
+        var cMessages = [llama_chat_message]()
+        for i in 0..<messages.count {
+            cMessages.append(llama_chat_message(role: cRoles[i], content: cContents[i]))
+        }
+
+        // Get chat template embedded in the model's GGUF file (e.g., Llama 3, Mistral, ChatML)
+        let tmpl = llama_model_chat_template(model, nil)
+
+        // Get required buffer size
+        let requiredSize = llama_chat_apply_template(
+            tmpl,
+            cMessages,
+            cMessages.count,
+            true, // add_ass: Add assistant generation prompt
+            nil,
+            0
+        )
+
+        guard requiredSize > 0 else {
+            throw LlamaLanguageModelError.encodingFailed
+        }
+
+        // Allocate buffer and apply template
+        var buffer = [CChar](repeating: 0, count: Int(requiredSize) + 1)
+
+        let result = llama_chat_apply_template(
+            tmpl,
+            cMessages,
+            cMessages.count,
+            true,
+            &buffer,
+            Int32(buffer.count)
+        )
+
+        guard result > 0 else {
+            throw LlamaLanguageModelError.encodingFailed
+        }
+
+        return buffer.withUnsafeBytes { rawBuffer in
+            String(decoding: rawBuffer.prefix(Int(result)), as: UTF8.self)
+        }
+    }
+
+    private func extractText(from segments: [Transcript.Segment]) -> String {
+        segments.compactMap { segment -> String? in
+            if case .text(let t) = segment { return t.content }
+            return nil
+        }.joined()
+    }
+
     private func tokenizeText(vocab: OpaquePointer, text: String) throws -> [llama_token] {
         let utf8Count = text.utf8.count
         let maxTokens = Int32(max(utf8Count * 2, 8)) // Rough estimate, minimum capacity
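
The hunk is truncated here, but the capacity estimate above feeds llama.cpp's usual two-pass tokenize idiom: call llama_tokenize with a guessed buffer, and if it returns a negative value, the magnitude is the exact token count needed for a retry. A minimal sketch of that pattern under the llama.h signature; the actual method in this file may differ in details:

    // Sketch only: the standard llama.cpp idiom, not necessarily this file's exact code.
    private func tokenizeTwoPass(vocab: OpaquePointer, text: String) throws -> [llama_token] {
        var tokens = [llama_token](repeating: 0, count: max(text.utf8.count * 2, 8))
        var count = llama_tokenize(vocab, text, Int32(text.utf8.count),
                                   &tokens, Int32(tokens.count),
                                   true,  // add_special: add BOS/EOS as the vocab dictates
                                   true)  // parse_special: honor special tokens in the text
        if count < 0 {
            // Buffer was too small; -count is the exact number of tokens required.
            tokens = [llama_token](repeating: 0, count: Int(-count))
            count = llama_tokenize(vocab, text, Int32(text.utf8.count),
                                   &tokens, Int32(tokens.count), true, true)
        }
        guard count >= 0 else { throw LlamaLanguageModelError.encodingFailed }
        return Array(tokens.prefix(Int(count)))
    }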