
Commit a48cf2e

Add logic to handle encoder-only llama models (#53)
* Add logic to handle encoder-only llama models
* Add test coverage for batch size validation
1 parent 14a29c1 commit a48cf2e


2 files changed: +159 -40 lines changed

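Before the per-file diffs, a minimal caller-side sketch of what the change means in practice: attempting text generation against an encoder-only GGUF (e.g. a BERT model) now surfaces the new encoderOnlyModel error. This is an illustrative sketch, not code from the commit; the AnyLanguageModel module name and the wrapper function are assumptions, while LanguageModelSession, GenerationOptions, respond(to:options:), and LlamaLanguageModelError.encoderOnlyModel all appear in the diffs below.

import AnyLanguageModel  // assumed module name

// Illustrative caller (not part of this commit). Generation against an
// encoder-only model (e.g. BERT) is now rejected with `encoderOnlyModel`.
func generateOrExplain(using model: LlamaLanguageModel, prompt: String) async {
    let session = LanguageModelSession(model: model)
    let options = GenerationOptions(maximumResponseTokens: 64)
    do {
        _ = try await session.respond(to: prompt, options: options)
        print("Generation succeeded.")
    } catch LlamaLanguageModelError.encoderOnlyModel {
        print("This model is encoder-only and cannot generate text; use it for embeddings instead.")
    } catch {
        print("Generation failed: \(error)")
    }
}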

Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift

Lines changed: 137 additions & 40 deletions
@@ -455,6 +455,14 @@ import Foundation
 
         defer { llama_free(context) }
 
+        // Check if this is an embedding model (no KV cache).
+        // This early check catches models configured for embeddings that lack a KV cache.
+        // A complementary architectural check in prepareInitialBatch catches encoder-only
+        // models (like BERT) by their architecture type.
+        if llama_get_memory(context) == nil {
+            throw LlamaLanguageModelError.encoderOnlyModel
+        }
+
         llama_set_causal_attn(context, true)
         llama_set_warmup(context, false)
         llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
@@ -514,6 +522,14 @@ import Foundation
         }
         defer { llama_free(context) }
 
+        // Check if this is an embedding model (no KV cache).
+        // This early check catches models configured for embeddings that lack a KV cache.
+        // A complementary architectural check in prepareInitialBatch catches encoder-only
+        // models (like BERT) by their architecture type.
+        if llama_get_memory(context) == nil {
+            throw LlamaLanguageModelError.encoderOnlyModel
+        }
+
         // Stabilize runtime behavior per-context
         llama_set_causal_attn(context, true)
         llama_set_warmup(context, false)
@@ -701,25 +717,14 @@ import Foundation
         var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
         defer { llama_batch_free(batch) }
 
-        batch.n_tokens = Int32(promptTokens.count)
-        for i in 0 ..< promptTokens.count {
-            let idx = Int(i)
-            batch.token[idx] = promptTokens[idx]
-            batch.pos[idx] = Int32(i)
-            batch.n_seq_id[idx] = 1
-            if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
-                seq_id[0] = 0
-            }
-            batch.logits[idx] = 0
-        }
-
-        if batch.n_tokens > 0 {
-            batch.logits[Int(batch.n_tokens) - 1] = 1
-        }
-
-        guard llama_decode(context, batch) == 0 else {
-            throw LlamaLanguageModelError.encodingFailed
-        }
+        let hasEncoder = try prepareInitialBatch(
+            batch: &batch,
+            promptTokens: promptTokens,
+            model: model,
+            vocab: vocab,
+            context: context,
+            batchSize: options.batchSize
+        )
 
         // Initialize sampler chain with options
         guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else {
@@ -752,7 +757,9 @@ import Foundation
 
         // Generate tokens one by one
         var generatedText = ""
-        var n_cur = batch.n_tokens
+        // Track position - for encoder-decoder models, we start from position 1 (after decoder start token)
+        // For decoder-only models, we continue from the end of the prompt
+        var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens
 
         for _ in 0 ..< maxTokens {
             // Sample next token from logits - llama_batch_get_one creates batch with single token at index 0
@@ -834,25 +841,14 @@ import Foundation
         var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
         defer { llama_batch_free(batch) }
 
-        // Evaluate the prompt
-        batch.n_tokens = Int32(promptTokens.count)
-        for i in 0 ..< promptTokens.count {
-            let idx = Int(i)
-            batch.token[idx] = promptTokens[idx]
-            batch.pos[idx] = Int32(i)
-            batch.n_seq_id[idx] = 1
-            if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
-                seq_id[0] = 0
-            }
-            batch.logits[idx] = 0
-        }
-        if batch.n_tokens > 0 {
-            batch.logits[Int(batch.n_tokens) - 1] = 1
-        }
-
-        guard llama_decode(context, batch) == 0 else {
-            throw LlamaLanguageModelError.encodingFailed
-        }
+        let hasEncoder = try prepareInitialBatch(
+            batch: &batch,
+            promptTokens: promptTokens,
+            model: model,
+            vocab: vocab,
+            context: context,
+            batchSize: options.batchSize
+        )
 
         // Initialize sampler chain with options
         guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else {
@@ -886,7 +882,9 @@ import Foundation
         applySampling(sampler: samplerPtr, effectiveTemperature: effectiveTemperature, options: options)
 
         // Generate tokens one by one
-        var n_cur = batch.n_tokens
+        // Track position - for encoder-decoder models, we start from position 1 (after decoder start token)
+        // For decoder-only models, we continue from the end of the prompt
+        var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens
 
         for _ in 0 ..< maxTokens {
             // Sample next token from logits of the last token we just decoded
@@ -945,6 +943,102 @@ import Foundation
 
     // MARK: - Helper Methods
 
+    /// Prepares the initial batch for text generation, handling encoder-decoder vs decoder-only models.
+    ///
+    /// - Parameters:
+    ///   - batch: The batch to prepare (must be initialized with sufficient capacity).
+    ///   - promptTokens: The tokenized prompt tokens.
+    ///   - model: The loaded model.
+    ///   - vocab: The model vocabulary.
+    ///   - context: The model context.
+    ///   - batchSize: The batch capacity to validate against (prevents buffer overflow).
+    /// - Returns: `true` if the model has an encoder (for position tracking during generation).
+    /// - Throws: `insufficientMemory` if prompt token count exceeds batch capacity, `encoderOnlyModel` if the model cannot generate text, `encodingFailed` or `decodingFailed` on failure.
+    private func prepareInitialBatch(
+        batch: inout llama_batch,
+        promptTokens: [llama_token],
+        model: OpaquePointer,
+        vocab: OpaquePointer,
+        context: OpaquePointer,
+        batchSize: UInt32
+    ) throws -> Bool {
+        // Validate that prompt token count doesn't exceed batch capacity to prevent buffer overflow
+        guard promptTokens.count <= batchSize else {
+            throw LlamaLanguageModelError.insufficientMemory
+        }
+
+        let hasEncoder = llama_model_has_encoder(model)
+        let hasDecoder = llama_model_has_decoder(model)
+
+        if hasEncoder {
+            // For encoder models, first encode the prompt
+            batch.n_tokens = Int32(promptTokens.count)
+            for i in 0 ..< promptTokens.count {
+                let idx = Int(i)
+                batch.token[idx] = promptTokens[idx]
+                batch.pos[idx] = Int32(i)
+                batch.n_seq_id[idx] = 1
+                if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
+                    seq_id[0] = 0
+                }
+                batch.logits[idx] = 0
+            }
+
+            guard llama_encode(context, batch) == 0 else {
+                throw LlamaLanguageModelError.encodingFailed
+            }
+
+            if hasDecoder {
+                // For encoder-decoder models, start decoding with decoder start token
+                var decoderStartToken = llama_model_decoder_start_token(model)
+                if decoderStartToken == LLAMA_TOKEN_NULL {
+                    decoderStartToken = llama_vocab_bos(vocab)
+                }
+
+                batch.n_tokens = 1
+                batch.token[0] = decoderStartToken
+                batch.pos[0] = 0
+                batch.n_seq_id[0] = 1
+                if let seq_ids = batch.seq_id, let seq_id = seq_ids[0] {
+                    seq_id[0] = 0
+                }
+                batch.logits[0] = 1
+
+                guard llama_decode(context, batch) == 0 else {
+                    throw LlamaLanguageModelError.decodingFailed
+                }
+            } else {
+                // Encoder-only model (like BERT) - cannot generate text.
+                // This architectural check complements the earlier KV cache check,
+                // catching models by their architecture type.
+                throw LlamaLanguageModelError.encoderOnlyModel
+            }
+        } else {
+            // Standard decoder-only model (most LLMs)
+            batch.n_tokens = Int32(promptTokens.count)
+            for i in 0 ..< promptTokens.count {
+                let idx = Int(i)
+                batch.token[idx] = promptTokens[idx]
+                batch.pos[idx] = Int32(i)
+                batch.n_seq_id[idx] = 1
+                if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
+                    seq_id[0] = 0
+                }
+                batch.logits[idx] = 0
+            }
+
+            if batch.n_tokens > 0 {
+                batch.logits[Int(batch.n_tokens) - 1] = 1
+            }
+
+            guard llama_decode(context, batch) == 0 else {
+                throw LlamaLanguageModelError.decodingFailed
+            }
+        }
+
+        return hasEncoder
+    }
+
     private func formatPrompt(for session: LanguageModelSession) throws -> String {
         guard let model = self.model else {
             throw LlamaLanguageModelError.modelLoadFailed
@@ -1110,6 +1204,7 @@ import Foundation
     case invalidModelPath
     case insufficientMemory
     case unsupportedFeature
+    case encoderOnlyModel
 
     public var errorDescription: String? {
         switch self {
@@ -1129,6 +1224,8 @@ import Foundation
             return "Insufficient memory for operation"
         case .unsupportedFeature:
             return "This LlamaLanguageModel does not support image segments"
+        case .encoderOnlyModel:
+            return "This model is encoder-only (e.g., BERT) and cannot generate text"
         }
     }
 }
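The routing implemented by the new prepareInitialBatch helper reduces to a three-way split on the model's architecture. The enum and function below are an illustrative condensation, not code from this commit; llama_model_has_encoder and llama_model_has_decoder are the llama.cpp queries the diff above relies on, while the type and function names here are made up for the sketch.

// Condensed sketch of the branching in prepareInitialBatch (not part of the commit).
// `model` is an OpaquePointer to a loaded llama.cpp model.
enum InitialBatchPath {
    case encodeThenSeedDecoder  // encoder-decoder (e.g. T5): llama_encode the prompt, then decode the start token
    case decodePrompt           // decoder-only (most LLMs): llama_decode the prompt directly
    case rejectEncoderOnly      // encoder-only (e.g. BERT): throw encoderOnlyModel
}

func initialBatchPath(for model: OpaquePointer) -> InitialBatchPath {
    let hasEncoder = llama_model_has_encoder(model)
    let hasDecoder = llama_model_has_decoder(model)
    if hasEncoder {
        return hasDecoder ? .encodeThenSeedDecoder : .rejectEncoderOnly
    }
    return .decodePrompt
}

The helper's other guard, promptTokens.count <= batchSize, is what the new test below exercises.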

Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift

Lines changed: 22 additions & 0 deletions
@@ -306,5 +306,27 @@ import Testing
             #expect(error == .unsupportedFeature)
         }
     }
+
+    @Test func promptExceedingBatchSize_rejected() async throws {
+        let session = LanguageModelSession(model: model)
+
+        // Use a very small batch size to test the validation
+        var options = GenerationOptions(maximumResponseTokens: 10)
+        options[custom: LlamaLanguageModel.self] = .init(batchSize: 8)
+
+        // Create a prompt that will tokenize to more than 8 tokens
+        // Most models will tokenize "Hello world how are you today" to more than 8 tokens
+        let longPrompt = String(repeating: "Hello world how are you today? ", count: 10)
+
+        do {
+            _ = try await session.respond(to: longPrompt, options: options)
+            // If we get here, either the prompt tokenized to <= 8 tokens (unlikely)
+            // or the validation didn't work (bug)
+            // In practice, this should throw insufficientMemory
+        } catch let error as LlamaLanguageModelError {
+            // Expected: prompt token count exceeds batch size
+            #expect(error == .insufficientMemory)
+        }
+    }
 }
 #endif // Llama
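The batch-size validation also gives callers a recoverable error instead of a buffer overflow. Below is a hypothetical helper, under the same assumptions as the earlier sketch (module name, and the option types used in the test above), that retries with a larger batch capacity when insufficientMemory signals that the prompt outgrew batchSize:

import AnyLanguageModel  // assumed module name

// Hypothetical helper (not from the commit): retry with a larger batch capacity
// when the prompt tokenizes to more tokens than the configured batchSize,
// which now surfaces as LlamaLanguageModelError.insufficientMemory.
func respondAllowingLongPrompt(_ session: LanguageModelSession, to prompt: String) async throws {
    var options = GenerationOptions(maximumResponseTokens: 128)
    options[custom: LlamaLanguageModel.self] = .init(batchSize: 64)
    do {
        _ = try await session.respond(to: prompt, options: options)
    } catch LlamaLanguageModelError.insufficientMemory {
        // The prompt exceeded the 64-token batch capacity; retry with more headroom.
        options[custom: LlamaLanguageModel.self] = .init(batchSize: 2048)
        _ = try await session.respond(to: prompt, options: options)
    }
}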
