@@ -1077,7 +1077,6 @@ int main(int argc, char ** argv) {
     const int32_t n_shared_len = 13;  // prime number, shared prompt length
     const int32_t n_seq_len    = 127; // prime number
 
-    llama_batch batch = llama_batch_init(n_batch, 0, 1);
     // TODO: batch with embeddings
 
     std::vector<model_variant> model_variants;
@@ -1119,6 +1118,8 @@ int main(int argc, char ** argv) {
         // TODO: avoid re-creating reference outputs
         for (int32_t n_seq_max : { 1, 2, 5 }) {
 
+            llama_batch batch = llama_batch_init(n_batch, 0, n_seq_max);
+
             // TODO(later): context shift testing
             for (int32_t n_ctx : { n_seq_len * n_seq_max }) {
@@ -1195,6 +1196,7 @@ int main(int argc, char ** argv) {
                 for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
                     seq_id_group.push_back(seq_id);
                     seq_id_n_past[seq_id] += shared_prompt.size();
+                    seq_ids_in_batch.insert(seq_id);
                 };
 
                 for (size_t i = 0; i < shared_prompt.size(); ++i) {
@@ -1272,12 +1274,12 @@ int main(int argc, char ** argv) {
                     }
                 }
             }
+
+            llama_batch_free(batch);
         }
 
         llama_model_free(model);
     }
 
-    llama_batch_free(batch);
-
     return 0;
 }
0 commit comments