Skip to content

Commit acafecf

Browse files
committed
Implement chat templates for LlamaLanguageModel
1 parent 1fa6629 commit acafecf

File tree

1 file changed

+96
-2
lines changed

1 file changed

+96
-2
lines changed

Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,10 +460,12 @@ import Foundation
460460
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
461461

462462
let maxTokens = runtimeOptions.maximumResponseTokens ?? 100
463+
let fullPrompt = try formatPrompt(for: session)
464+
463465
let text = try await generateText(
464466
context: context,
465467
model: model!,
466-
prompt: prompt.description,
468+
prompt: fullPrompt,
467469
maxTokens: maxTokens,
468470
options: runtimeOptions
469471
)
@@ -518,12 +520,13 @@ import Foundation
518520
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
519521

520522
var accumulatedText = ""
523+
let fullPrompt = try self.formatPrompt(for: session)
521524

522525
do {
523526
for try await tokenText in generateTextStream(
524527
context: context,
525528
model: model!,
526-
prompt: prompt.description,
529+
prompt: fullPrompt,
527530
maxTokens: maxTokens,
528531
options: runtimeOptions
529532
) {
@@ -942,6 +945,97 @@ import Foundation
942945

943946
// MARK: - Helper Methods
944947

948+
/// Builds the full generation prompt by applying the model's embedded chat
/// template to the session transcript.
///
/// - Parameter session: Session whose transcript entries (instructions,
///   prompts, responses) are converted into role-tagged chat messages.
/// - Returns: The templated prompt text, including the assistant generation
///   marker (`add_ass` is passed as `true`).
/// - Throws: `LlamaLanguageModelError.modelLoadFailed` when no model is
///   loaded; `LlamaLanguageModelError.encodingFailed` when the template
///   cannot be applied (either sizing call or final render returns <= 0).
private func formatPrompt(for session: LanguageModelSession) throws -> String {
    guard let model = self.model else {
        throw LlamaLanguageModelError.modelLoadFailed
    }

    // Map transcript entries onto chat roles. Empty-text entries and any
    // other entry kinds (the `default` arm) are skipped.
    var messages: [(role: String, content: String)] = []

    for entry in session.transcript {
        switch entry {
        case .instructions(let instructions):
            let text = extractText(from: instructions.segments)
            if !text.isEmpty {
                messages.append(("system", text))
            }

        case .prompt(let prompt):
            let text = extractText(from: prompt.segments)
            if !text.isEmpty {
                messages.append(("user", text))
            }

        case .response(let response):
            let text = extractText(from: response.segments)
            if !text.isEmpty {
                messages.append(("assistant", text))
            }

        default:
            break
        }
    }

    // Keep C strings alive while using them: llama_chat_message stores raw
    // `char *` pointers, so the duplicated strings must outlive both
    // llama_chat_apply_template calls below. Freed by the defer on every exit.
    let cRoles = messages.map { strdup($0.role) }
    let cContents = messages.map { strdup($0.content) }

    defer {
        cRoles.forEach { free($0) }
        cContents.forEach { free($0) }
    }

    var cMessages = [llama_chat_message]()
    for i in 0 ..< messages.count {
        cMessages.append(llama_chat_message(role: cRoles[i], content: cContents[i]))
    }

    // Get chat template embedded in the model's GGUF file (e.g., Llama 3, Mistral, ChatML)
    // NOTE(review): `tmpl` may be nil when the GGUF carries no template key;
    // presumably llama_chat_apply_template then falls back to a default
    // (ChatML) — confirm against the pinned llama.cpp version.
    let tmpl = llama_model_chat_template(model, nil)

    // Get required buffer size: a nil/0 buffer makes the call return the
    // number of bytes the rendered prompt needs.
    let requiredSize = llama_chat_apply_template(
        tmpl,
        cMessages,
        cMessages.count,
        true, // add_ass: Add assistant generation prompt
        nil,
        0
    )

    guard requiredSize > 0 else {
        throw LlamaLanguageModelError.encodingFailed
    }

    // Allocate buffer and apply template (+1 leaves room for a NUL terminator).
    var buffer = [CChar](repeating: 0, count: Int(requiredSize) + 1)

    let result = llama_chat_apply_template(
        tmpl,
        cMessages,
        cMessages.count,
        true,
        &buffer,
        Int32(buffer.count)
    )

    guard result > 0 else {
        throw LlamaLanguageModelError.encodingFailed
    }

    // Decode only the `result` bytes actually written; the buffer is
    // deliberately one byte larger than the rendered text.
    return buffer.withUnsafeBytes { rawBuffer in
        String(decoding: rawBuffer.prefix(Int(result)), as: UTF8.self)
    }
}
1032+
/// Concatenates the text content of every `.text` segment, in order,
/// ignoring all other segment kinds. Returns "" when nothing matches.
private func extractText(from segments: [Transcript.Segment]) -> String {
    var combined = ""
    for segment in segments {
        guard case .text(let textSegment) = segment else { continue }
        combined += textSegment.content
    }
    return combined
}
9451039
private func tokenizeText(vocab: OpaquePointer, text: String) throws -> [llama_token] {
9461040
let utf8Count = text.utf8.count
9471041
let maxTokens = Int32(max(utf8Count * 2, 8)) // Rough estimate, minimum capacity

0 commit comments

Comments
 (0)