Skip to content

Commit 39a76c3

Browse files
committed
Implement chat templates for LlamaLanguageModel
1 parent cb7e45e commit 39a76c3

File tree

1 file changed

+96
-2
lines changed

1 file changed

+96
-2
lines changed

Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,12 @@ import Foundation
457457
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
458458

459459
let maxTokens = runtimeOptions.maximumResponseTokens ?? 100
460+
let fullPrompt = try formatPrompt(for: session)
461+
460462
let text = try await generateText(
461463
context: context,
462464
model: model!,
463-
prompt: prompt.description,
465+
prompt: fullPrompt,
464466
maxTokens: maxTokens,
465467
options: runtimeOptions
466468
)
@@ -515,12 +517,13 @@ import Foundation
515517
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
516518

517519
var accumulatedText = ""
520+
let fullPrompt = try self.formatPrompt(for: session)
518521

519522
do {
520523
for try await tokenText in generateTextStream(
521524
context: context,
522525
model: model!,
523-
prompt: prompt.description,
526+
prompt: fullPrompt,
524527
maxTokens: maxTokens,
525528
options: runtimeOptions
526529
) {
@@ -939,6 +942,97 @@ import Foundation
939942

940943
// MARK: - Helper Methods
941944

945+
/// Renders the session's transcript into a single prompt string using the
/// chat template embedded in the loaded GGUF model (e.g. Llama 3, Mistral,
/// ChatML).
///
/// - Parameter session: The session whose transcript (instructions, prompts,
///   responses) is flattened into role/content chat messages.
/// - Returns: The fully templated prompt, ending with the assistant
///   generation header (`add_ass == true`) so the model continues as the
///   assistant.
/// - Throws: `LlamaLanguageModelError.modelLoadFailed` when no model is
///   loaded; `LlamaLanguageModelError.encodingFailed` when the template
///   cannot be applied or produces no output.
private func formatPrompt(for session: LanguageModelSession) throws -> String {
    guard let model = self.model else {
        throw LlamaLanguageModelError.modelLoadFailed
    }

    // Flatten the transcript into (role, content) pairs. Entries whose
    // segments yield no text are skipped; non-message entries (e.g. tool
    // calls) are ignored.
    var messages: [(role: String, content: String)] = []

    for entry in session.transcript {
        switch entry {
        case .instructions(let instructions):
            let text = extractText(from: instructions.segments)
            if !text.isEmpty {
                messages.append(("system", text))
            }

        case .prompt(let prompt):
            let text = extractText(from: prompt.segments)
            if !text.isEmpty {
                messages.append(("user", text))
            }

        case .response(let response):
            let text = extractText(from: response.segments)
            if !text.isEmpty {
                messages.append(("assistant", text))
            }

        default:
            break
        }
    }

    // Duplicate role/content into C strings and keep them alive for the
    // duration of both llama_chat_apply_template calls.
    let cRoles = messages.map { strdup($0.role) }
    let cContents = messages.map { strdup($0.content) }

    defer {
        cRoles.forEach { free($0) }
        cContents.forEach { free($0) }
    }

    var cMessages = [llama_chat_message]()
    cMessages.reserveCapacity(messages.count)
    for i in 0 ..< messages.count {
        cMessages.append(llama_chat_message(role: cRoles[i], content: cContents[i]))
    }

    // Chat template embedded in the model's GGUF metadata. Some models ship
    // without one, in which case llama_model_chat_template returns NULL and
    // llama_chat_apply_template would fail outright — fall back to the
    // built-in "chatml" template instead of failing the whole request
    // (mirrors llama.cpp's own fallback behavior).
    let tmpl: String
    if let embedded = llama_model_chat_template(model, nil) {
        tmpl = String(cString: embedded)
    } else {
        tmpl = "chatml"
    }

    // First pass with a nil buffer: ask llama.cpp for the required size.
    let requiredSize = llama_chat_apply_template(
        tmpl,
        cMessages,
        cMessages.count,
        true, // add_ass: append the assistant generation prompt
        nil,
        0
    )

    guard requiredSize > 0 else {
        throw LlamaLanguageModelError.encodingFailed
    }

    // Second pass: render the template into a buffer sized from the first
    // pass (+1 for the trailing NUL).
    var buffer = [CChar](repeating: 0, count: Int(requiredSize) + 1)

    let result = llama_chat_apply_template(
        tmpl,
        cMessages,
        cMessages.count,
        true,
        &buffer,
        Int32(buffer.count)
    )

    // Negative result signals a template error; zero would be an empty
    // prompt, which is also treated as failure.
    guard result > 0 else {
        throw LlamaLanguageModelError.encodingFailed
    }

    // `result` is the number of valid bytes written (excluding the NUL).
    return buffer.withUnsafeBytes { rawBuffer in
        String(decoding: rawBuffer.prefix(Int(result)), as: UTF8.self)
    }
}
1028+
1029+
/// Concatenates the plain-text content of `segments`, skipping any
/// non-text segment kinds.
private func extractText(from segments: [Transcript.Segment]) -> String {
    var combined = ""
    for segment in segments {
        guard case .text(let textSegment) = segment else { continue }
        combined += textSegment.content
    }
    return combined
}
1035+
9421036
private func tokenizeText(vocab: OpaquePointer, text: String) throws -> [llama_token] {
9431037
let utf8Count = text.utf8.count
9441038
let maxTokens = Int32(max(utf8Count * 2, 8)) // Rough estimate, minimum capacity

0 commit comments

Comments
 (0)