@@ -455,6 +455,14 @@ import Foundation

         defer { llama_free(context) }

+        // Reject embedding-only models early: they are created without a KV cache,
+        // so llama_get_memory returns NULL. A complementary architectural check in
+        // prepareInitialBatch catches encoder-only models (like BERT) by their
+        // architecture type.
+        if llama_get_memory(context) == nil {
+            throw LlamaLanguageModelError.encoderOnlyModel
+        }
+
         llama_set_causal_attn(context, true)
         llama_set_warmup(context, false)
         llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
@@ -514,6 +522,14 @@ import Foundation
         }
         defer { llama_free(context) }

+        // Reject embedding-only models early: they are created without a KV cache,
+        // so llama_get_memory returns NULL. A complementary architectural check in
+        // prepareInitialBatch catches encoder-only models (like BERT) by their
+        // architecture type.
+        if llama_get_memory(context) == nil {
+            throw LlamaLanguageModelError.encoderOnlyModel
+        }
+
         // Stabilize runtime behavior per-context
         llama_set_causal_attn(context, true)
         llama_set_warmup(context, false)
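The same guard now appears in both context-creation paths. A minimal sketch of how it could be factored into one helper, assuming `context` is the usual `OpaquePointer` returned when the context is created (the helper name is hypothetical, not part of this diff):

```swift
// Hypothetical helper; both call sites would reduce to
// `try ensureGenerativeContext(context)`.
private func ensureGenerativeContext(_ context: OpaquePointer) throws {
    // Embedding-only contexts are created without a KV cache, so
    // llama_get_memory returns NULL for them.
    if llama_get_memory(context) == nil {
        throw LlamaLanguageModelError.encoderOnlyModel
    }
}
```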
@@ -701,25 +717,14 @@ import Foundation
         var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
         defer { llama_batch_free(batch) }

-        batch.n_tokens = Int32(promptTokens.count)
-        for i in 0..<promptTokens.count {
-            let idx = Int(i)
-            batch.token[idx] = promptTokens[idx]
-            batch.pos[idx] = Int32(i)
-            batch.n_seq_id[idx] = 1
-            if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
-                seq_id[0] = 0
-            }
-            batch.logits[idx] = 0
-        }
-
-        if batch.n_tokens > 0 {
-            batch.logits[Int(batch.n_tokens) - 1] = 1
-        }
-
-        guard llama_decode(context, batch) == 0 else {
-            throw LlamaLanguageModelError.encodingFailed
-        }
+        let hasEncoder = try prepareInitialBatch(
+            batch: &batch,
+            promptTokens: promptTokens,
+            model: model,
+            vocab: vocab,
+            context: context,
+            batchSize: options.batchSize
+        )

         // Initialize sampler chain with options
         guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else {
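Besides centralizing prompt evaluation for both generate paths, `prepareInitialBatch` adds a capacity guard that the inline version lacked. A sketch of what that buys; the 512-slot batch and the 600-token `longPrompt` are illustrative values, not taken from this diff:

```swift
var batch = llama_batch_init(512, 0, 1)
defer { llama_batch_free(batch) }
do {
    _ = try prepareInitialBatch(
        batch: &batch,
        promptTokens: longPrompt, // e.g. 600 tokens, more than the batch holds
        model: model,
        vocab: vocab,
        context: context,
        batchSize: 512
    )
} catch LlamaLanguageModelError.insufficientMemory {
    // The inline version would have written past the llama_batch buffers;
    // now the caller can truncate the prompt or raise options.batchSize.
}
```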
@@ -752,7 +757,9 @@ import Foundation

         // Generate tokens one by one
         var generatedText = ""
-        var n_cur = batch.n_tokens
+        // Track the position: encoder-decoder models start from position 1 (after the
+        // decoder start token); decoder-only models continue from the end of the prompt.
+        var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens

         for _ in 0..<maxTokens {
             // Sample next token from logits - llama_batch_get_one creates a batch with a single token at index 0
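The arithmetic behind `n_cur`: for encoder-decoder models, `prepareInitialBatch` leaves only the decoder start token in the decoder stream, at position 0, so the first generated token lands at position 1; for decoder-only models the prompt occupies positions `0..<N`, so generation continues at `N`. An illustrative per-token step that consumes `n_cur`, written in the same explicit-batch style the prompt setup uses (a sketch, not the exact loop body from this file):

```swift
// Sample from the logits of the last decoded token, then feed it back in.
let newToken = llama_sampler_sample(sampler, context, -1)
batch.n_tokens = 1
batch.token[0] = newToken
batch.pos[0] = n_cur // 1 on the first encoder-decoder step, prompt length otherwise
batch.n_seq_id[0] = 1
if let seqIds = batch.seq_id, let seqId = seqIds[0] {
    seqId[0] = 0
}
batch.logits[0] = 1
guard llama_decode(context, batch) == 0 else {
    throw LlamaLanguageModelError.decodingFailed
}
n_cur += 1
```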
@@ -834,25 +841,14 @@ import Foundation
         var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
         defer { llama_batch_free(batch) }

-        // Evaluate the prompt
-        batch.n_tokens = Int32(promptTokens.count)
-        for i in 0..<promptTokens.count {
-            let idx = Int(i)
-            batch.token[idx] = promptTokens[idx]
-            batch.pos[idx] = Int32(i)
-            batch.n_seq_id[idx] = 1
-            if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
-                seq_id[0] = 0
-            }
-            batch.logits[idx] = 0
-        }
-        if batch.n_tokens > 0 {
-            batch.logits[Int(batch.n_tokens) - 1] = 1
-        }
-
-        guard llama_decode(context, batch) == 0 else {
-            throw LlamaLanguageModelError.encodingFailed
-        }
+        let hasEncoder = try prepareInitialBatch(
+            batch: &batch,
+            promptTokens: promptTokens,
+            model: model,
+            vocab: vocab,
+            context: context,
+            batchSize: options.batchSize
+        )

         // Initialize sampler chain with options
         guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else {
@@ -886,7 +882,9 @@ import Foundation
         applySampling(sampler: samplerPtr, effectiveTemperature: effectiveTemperature, options: options)

         // Generate tokens one by one
-        var n_cur = batch.n_tokens
+        // Track the position: encoder-decoder models start from position 1 (after the
+        // decoder start token); decoder-only models continue from the end of the prompt.
+        var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens

         for _ in 0..<maxTokens {
             // Sample next token from logits of the last token we just decoded
@@ -945,6 +943,102 @@ import Foundation

     // MARK: - Helper Methods

+    /// Prepares the initial batch for text generation, handling encoder-decoder vs. decoder-only models.
+    ///
+    /// - Parameters:
+    ///   - batch: The batch to prepare (must be initialized with sufficient capacity).
+    ///   - promptTokens: The tokenized prompt.
+    ///   - model: The loaded model.
+    ///   - vocab: The model vocabulary.
+    ///   - context: The model context.
+    ///   - batchSize: The batch capacity to validate against (prevents a buffer overflow).
+    /// - Returns: `true` if the model has an encoder (used for position tracking during generation).
+    /// - Throws: `insufficientMemory` if the prompt token count exceeds the batch capacity, `encoderOnlyModel` if the model cannot generate text, and `encodingFailed` or `decodingFailed` if the initial encode/decode call fails.
+    private func prepareInitialBatch(
+        batch: inout llama_batch,
+        promptTokens: [llama_token],
+        model: OpaquePointer,
+        vocab: OpaquePointer,
+        context: OpaquePointer,
+        batchSize: UInt32
+    ) throws -> Bool {
+        // Validate that the prompt fits in the batch to prevent a buffer overflow.
+        guard promptTokens.count <= batchSize else {
+            throw LlamaLanguageModelError.insufficientMemory
+        }
+
+        let hasEncoder = llama_model_has_encoder(model)
+        let hasDecoder = llama_model_has_decoder(model)
+
+        if hasEncoder {
+            // For encoder models, first encode the prompt
+            batch.n_tokens = Int32(promptTokens.count)
+            for i in 0..<promptTokens.count {
+                let idx = Int(i)
+                batch.token[idx] = promptTokens[idx]
+                batch.pos[idx] = Int32(i)
+                batch.n_seq_id[idx] = 1
+                if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
+                    seq_id[0] = 0
+                }
+                batch.logits[idx] = 0
+            }
+
+            guard llama_encode(context, batch) == 0 else {
+                throw LlamaLanguageModelError.encodingFailed
+            }
+
+            if hasDecoder {
+                // For encoder-decoder models, start decoding with the decoder start token
+                var decoderStartToken = llama_model_decoder_start_token(model)
+                if decoderStartToken == LLAMA_TOKEN_NULL {
+                    decoderStartToken = llama_vocab_bos(vocab)
+                }
+
+                batch.n_tokens = 1
+                batch.token[0] = decoderStartToken
+                batch.pos[0] = 0
+                batch.n_seq_id[0] = 1
+                if let seq_ids = batch.seq_id, let seq_id = seq_ids[0] {
+                    seq_id[0] = 0
+                }
+                batch.logits[0] = 1
+
+                guard llama_decode(context, batch) == 0 else {
+                    throw LlamaLanguageModelError.decodingFailed
+                }
+            } else {
+                // Encoder-only model (like BERT): cannot generate text. This
+                // architectural check complements the earlier KV-cache check.
+                throw LlamaLanguageModelError.encoderOnlyModel
+            }
+        } else {
+            // Standard decoder-only model (most LLMs)
+            batch.n_tokens = Int32(promptTokens.count)
+            for i in 0..<promptTokens.count {
+                let idx = Int(i)
+                batch.token[idx] = promptTokens[idx]
+                batch.pos[idx] = Int32(i)
+                batch.n_seq_id[idx] = 1
+                if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
+                    seq_id[0] = 0
+                }
+                batch.logits[idx] = 0
+            }
+
+            if batch.n_tokens > 0 {
+                batch.logits[Int(batch.n_tokens) - 1] = 1
+            }
+
+            guard llama_decode(context, batch) == 0 else {
+                throw LlamaLanguageModelError.decodingFailed
+            }
+        }
+
+        return hasEncoder
+    }
+
     private func formatPrompt(for session: LanguageModelSession) throws -> String {
         guard let model = self.model else {
             throw LlamaLanguageModelError.modelLoadFailed
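Note the fallback on the decoder start token: `llama_model_decoder_start_token` returns `LLAMA_TOKEN_NULL` when the GGUF metadata does not record one, in which case BOS is a reasonable stand-in. At both call sites the helper reduces prompt evaluation to one call, with the return value driving the position bookkeeping seen earlier:

```swift
var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
defer { llama_batch_free(batch) }

let hasEncoder = try prepareInitialBatch(
    batch: &batch,
    promptTokens: promptTokens,
    model: model,
    vocab: vocab,
    context: context,
    batchSize: options.batchSize
)

// Decoder-only: the prompt fills positions 0..<N, so generation resumes at N.
// Encoder-decoder: the decoder stream holds only the start token at position 0,
// so generation resumes at 1.
var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens
```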
@@ -1110,6 +1204,7 @@ import Foundation
     case invalidModelPath
     case insufficientMemory
     case unsupportedFeature
+    case encoderOnlyModel

     public var errorDescription: String? {
         switch self {
@@ -1129,6 +1224,8 @@ import Foundation
             return "Insufficient memory for operation"
         case .unsupportedFeature:
             return "This LlamaLanguageModel does not support image segments"
+        case .encoderOnlyModel:
+            return "This model is encoder-only (e.g., BERT) and cannot generate text"
         }
     }
 }
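With the new error case, a BERT-style checkpoint fails fast with an actionable message instead of crashing inside `llama_decode`. A hypothetical call site; `generate(prompt:)` stands in for this class's actual entry point:

```swift
do {
    _ = try languageModel.generate(prompt: "Summarize: ...")
} catch LlamaLanguageModelError.encoderOnlyModel {
    // "This model is encoder-only (e.g., BERT) and cannot generate text"
    // -> prompt the user to pick a decoder or encoder-decoder GGUF instead.
} catch {
    // Other LlamaLanguageModelError cases (model load, decode, memory, ...).
}
```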