@@ -457,10 +457,12 @@ import Foundation
         llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)

         let maxTokens = runtimeOptions.maximumResponseTokens ?? 100
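+        // Render the whole session transcript through the model's chat
+        // template, rather than sending the raw prompt text.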
+        let fullPrompt = try formatPrompt(for: session)
+
         let text = try await generateText(
             context: context,
             model: model!,
-            prompt: prompt.description,
+            prompt: fullPrompt,
             maxTokens: maxTokens,
             options: runtimeOptions
         )
@@ -515,12 +517,13 @@ import Foundation
         llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)

         var accumulatedText = ""
+        let fullPrompt = try self.formatPrompt(for: session)

         do {
             for try await tokenText in generateTextStream(
                 context: context,
                 model: model!,
-                prompt: prompt.description,
+                prompt: fullPrompt,
                 maxTokens: maxTokens,
                 options: runtimeOptions
             ) {
@@ -939,6 +942,97 @@ import Foundation

     // MARK: - Helper Methods

+    private func formatPrompt(for session: LanguageModelSession) throws -> String {
+        guard let model = self.model else {
+            throw LlamaLanguageModelError.modelLoadFailed
+        }
+
+        var messages: [(role: String, content: String)] = []
+
+        for entry in session.transcript {
+            switch entry {
+            case .instructions(let instructions):
+                let text = extractText(from: instructions.segments)
+                if !text.isEmpty {
+                    messages.append(("system", text))
+                }
+
+            case .prompt(let prompt):
+                let text = extractText(from: prompt.segments)
+                if !text.isEmpty {
+                    messages.append(("user", text))
+                }
+
+            case .response(let response):
+                let text = extractText(from: response.segments)
+                if !text.isEmpty {
+                    messages.append(("assistant", text))
+                }
+
+            default:
+                break
+            }
+        }
+
+        // Keep C strings alive while using them
+        let cRoles = messages.map { strdup($0.role) }
+        let cContents = messages.map { strdup($0.content) }
+
+        defer {
+            cRoles.forEach { free($0) }
+            cContents.forEach { free($0) }
+        }
+
+        var cMessages = [llama_chat_message]()
+        for i in 0..<messages.count {
+            cMessages.append(llama_chat_message(role: cRoles[i], content: cContents[i]))
+        }
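+        // llama_chat_message is a plain C struct that only borrows these
+        // pointers, so the strdup'd buffers must stay alive (via the defer
+        // above) until templating is finished.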
+
+        // Get chat template embedded in the model's GGUF file (e.g., Llama 3, Mistral, ChatML)
+        let tmpl = llama_model_chat_template(model, nil)
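+        // Note: tmpl may be nil when the GGUF embeds no template; in that
+        // case llama_chat_apply_template falls back to llama.cpp's built-in
+        // default (ChatML-style formatting).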
+
+        // Get required buffer size
+        let requiredSize = llama_chat_apply_template(
+            tmpl,
+            cMessages,
+            cMessages.count,
+            true, // add_ass: add the assistant generation prompt
+            nil,
+            0
+        )
+
+        guard requiredSize > 0 else {
+            throw LlamaLanguageModelError.encodingFailed
+        }
+
+        // Allocate buffer and apply template
+        var buffer = [CChar](repeating: 0, count: Int(requiredSize) + 1)
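+        // +1 byte of headroom: the size reported by the sizing call does not
+        // appear to include a trailing NUL, and we decode exactly `result`
+        // bytes below in any case.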
+
+        let result = llama_chat_apply_template(
+            tmpl,
+            cMessages,
+            cMessages.count,
+            true,
+            &buffer,
+            Int32(buffer.count)
+        )
+
+        guard result > 0 else {
+            throw LlamaLanguageModelError.encodingFailed
+        }
+
+        return buffer.withUnsafeBytes { rawBuffer in
+            String(decoding: rawBuffer.prefix(Int(result)), as: UTF8.self)
+        }
+    }
+
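+    // For illustration only (hypothetical transcript, Llama 3-style
+    // template), formatPrompt(for:) returns a string shaped roughly like:
+    //
+    //   <|start_header_id|>system<|end_header_id|>
+    //
+    //   You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
+    //
+    //   What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+    //
+    // The trailing assistant header comes from add_ass, cueing the model to
+    // generate the next response.
+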
+    private func extractText(from segments: [Transcript.Segment]) -> String {
+        segments.compactMap { segment -> String? in
+            if case .text(let t) = segment { return t.content }
+            return nil
+        }.joined()
+    }
+
     private func tokenizeText(vocab: OpaquePointer, text: String) throws -> [llama_token] {
         let utf8Count = text.utf8.count
         let maxTokens = Int32(max(utf8Count * 2, 8)) // Rough estimate, minimum capacity