98 changes: 96 additions & 2 deletions Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift
@@ -460,10 +460,12 @@ import Foundation
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)

let maxTokens = runtimeOptions.maximumResponseTokens ?? 100
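        // Render the whole session transcript through the model's chat template
        // rather than sending the raw prompt text.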
let fullPrompt = try formatPrompt(for: session)

let text = try await generateText(
context: context,
model: model!,
prompt: prompt.description,
prompt: fullPrompt,
maxTokens: maxTokens,
options: runtimeOptions
)
Expand Down Expand Up @@ -518,12 +520,13 @@ import Foundation
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)

var accumulatedText = ""
let fullPrompt = try self.formatPrompt(for: session)

do {
for try await tokenText in generateTextStream(
context: context,
model: model!,
prompt: prompt.description,
prompt: fullPrompt,
maxTokens: maxTokens,
options: runtimeOptions
) {
@@ -942,6 +945,97 @@ import Foundation

// MARK: - Helper Methods

private func formatPrompt(for session: LanguageModelSession) throws -> String {
guard let model = self.model else {
throw LlamaLanguageModelError.modelLoadFailed
}

var messages: [(role: String, content: String)] = []

for entry in session.transcript {
switch entry {
case .instructions(let instructions):
let text = extractText(from: instructions.segments)
if !text.isEmpty {
messages.append(("system", text))
}

case .prompt(let prompt):
let text = extractText(from: prompt.segments)
if !text.isEmpty {
messages.append(("user", text))
}

case .response(let response):
let text = extractText(from: response.segments)
if !text.isEmpty {
messages.append(("assistant", text))
}

default:
break
}
}

// Keep C strings alive while using them
let cRoles = messages.map { strdup($0.role) }
let cContents = messages.map { strdup($0.content) }
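        // (llama_chat_message stores raw `const char *` fields, so Swift's
        // call-scoped String-to-C-string bridging would leave dangling pointers;
        // strdup copies survive until the deferred free below.)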

defer {
cRoles.forEach { free($0) }
cContents.forEach { free($0) }
}

var cMessages = [llama_chat_message]()
for i in 0 ..< messages.count {
cMessages.append(llama_chat_message(role: cRoles[i], content: cContents[i]))
}

// Get chat template embedded in the model's GGUF file (e.g., Llama 3, Mistral, ChatML)
let tmpl = llama_model_chat_template(model, nil)
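        // tmpl may be nil if the GGUF embeds no template; llama_chat_apply_template
        // then falls back to its built-in ChatML template.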

// Get required buffer size
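        // (a nil output buffer makes llama_chat_apply_template return the
        // formatted length without writing anything)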
let requiredSize = llama_chat_apply_template(
tmpl,
cMessages,
cMessages.count,
true, // add_ass: Add assistant generation prompt
nil,
0
)

guard requiredSize > 0 else {
throw LlamaLanguageModelError.encodingFailed
}

// Allocate buffer and apply template
var buffer = [CChar](repeating: 0, count: Int(requiredSize) + 1)

let result = llama_chat_apply_template(
tmpl,
cMessages,
cMessages.count,
true,
&buffer,
Int32(buffer.count)
)

guard result > 0 else {
throw LlamaLanguageModelError.encodingFailed
}

return buffer.withUnsafeBytes { rawBuffer in
String(decoding: rawBuffer.prefix(Int(result)), as: UTF8.self)
}
}
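
    // Example (hypothetical, assuming a Llama 3-style template embedded in the
    // GGUF): a transcript with one system and one user turn renders roughly as
    //   <|start_header_id|>system<|end_header_id|>…<|eot_id|>
    //   <|start_header_id|>user<|end_header_id|>…<|eot_id|>
    //   <|start_header_id|>assistant<|end_header_id|>
    // with the trailing assistant header appended because add_ass is true.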

private func extractText(from segments: [Transcript.Segment]) -> String {
segments.compactMap { segment -> String? in
if case .text(let t) = segment { return t.content }
return nil
}.joined()
}

private func tokenizeText(vocab: OpaquePointer, text: String) throws -> [llama_token] {
let utf8Count = text.utf8.count
let maxTokens = Int32(max(utf8Count * 2, 8)) // Rough estimate, minimum capacity