@@ -832,9 +832,11 @@ struct reference_logits {
     std::vector<llama_token> inputs;
     std::vector<float>       outputs;
 
-    reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng) {
+    reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng,
+                     const std::vector<llama_token> & shared_prompt) {
         n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
         std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
+        GGML_ASSERT(shared_prompt.size() < (size_t) (seq_len / 4));
         std::uniform_int_distribution<int32_t> rand_prompt_len(seq_len / 4, 3 * seq_len / 4);
 
         llama_batch batch = llama_batch_init(seq_len, 0, 1);
@@ -843,7 +845,14 @@ struct reference_logits {
 
         prompt_len = rand_prompt_len(rng);
 
-        for (int32_t i = 0; i < prompt_len; ++i) {
+        for (int32_t i = 0; i < (int32_t) shared_prompt.size(); ++i) {
+            const llama_token token = shared_prompt[i];
+            inputs.push_back(token);
+
+            common_batch_add(batch, token, i, { 0 }, true);
+        }
+
+        for (int32_t i = shared_prompt.size(); i < prompt_len; ++i) {
             const llama_token token = rand_token(rng);
             inputs.push_back(token);
 
@@ -1065,6 +1074,7 @@ int main(int argc, char ** argv) {
 
     // TODO: multiple sequences per token
     const int32_t n_batch = 509; // prime number
+    const int32_t n_shared_len = 13; // prime number, shared prompt length
     const int32_t n_seq_len = 127; // prime number
 
     llama_batch batch = llama_batch_init(n_batch, 0, 1);
@@ -1092,9 +1102,21 @@ int main(int argc, char ** argv) {
 
     GGML_ASSERT(model);
 
-    // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
+    const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
     // const auto n_embd = llama_model_n_embd(model);
 
+    std::vector<llama_token> shared_prompt;
+    // populate shared prompt
+    {
+        std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
+        shared_prompt.reserve(n_shared_len);
+
+        for (int32_t i = 0; i < n_shared_len; ++i) {
+            shared_prompt.push_back(rand_token(rng));
+        }
+    }
+
+    // TODO: avoid re-creating reference outputs
     for (int32_t n_seq_max : { 1, 2, 5 }) {
 
         // TODO(later): context shift testing
@@ -1119,111 +1141,134 @@ int main(int argc, char ** argv) {
 
             for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
                 llama_memory_clear(mem, true);
-                ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng));
+                ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng, shared_prompt));
             }
 
             llama_free(ref_ctx);
         }
 
-        for (bool shuffle : { false, true }) {
+        for (bool use_shared_prompt : { false, true }) {
+            for (bool shuffle : { false, true }) {
 
-            // can't really shuffle a single sequence with itself
-            if (shuffle && n_seq_max == 1) {
-                continue;
-            }
+                // can't really shuffle a single sequence with itself
+                if (shuffle && n_seq_max == 1) {
+                    continue;
+                }
+                // can't really share a prompt with only one sequence
+                if (use_shared_prompt && n_seq_max == 1) {
+                    continue;
+                }
 
-            for (int32_t n_ubatch : { 1, 2, 512 }) {
+                for (int32_t n_ubatch : { 1, 2, 512 }) {
 
-                std::vector<bool> valid(n_seq_max, true);
+                    std::vector<bool> valid(n_seq_max, true);
 
-                llama_context_params ctx_params = llama_context_default_params();
-                ctx_params.n_ctx = n_ctx;
-                ctx_params.n_seq_max = n_seq_max;
-                ctx_params.n_ubatch = n_ubatch;
-                ctx_params.n_batch = n_batch;
-                // TODO: remove once F16 is fixed on ARM
-                ctx_params.type_k = GGML_TYPE_F32;
-                ctx_params.type_v = GGML_TYPE_F32;
+                    llama_context_params ctx_params = llama_context_default_params();
+                    ctx_params.n_ctx = n_ctx;
+                    ctx_params.n_seq_max = n_seq_max;
+                    ctx_params.n_ubatch = n_ubatch;
+                    ctx_params.n_batch = n_batch;
+                    // TODO: remove once F16 is fixed on ARM
+                    ctx_params.type_k = GGML_TYPE_F32;
+                    ctx_params.type_v = GGML_TYPE_F32;
 
-                llama_context * ctx = llama_init_from_model(model, ctx_params);
+                    llama_context * ctx = llama_init_from_model(model, ctx_params);
 
-                common_batch_clear(batch);
+                    common_batch_clear(batch);
 
-                std::set<llama_seq_id> seq_ids_in_batch;
-                std::vector<llama_pos> seq_id_n_past(n_seq_max, 0);
+                    std::set<llama_seq_id> seq_ids_in_batch;
+                    std::vector<llama_pos> seq_id_n_past(n_seq_max, 0);
 
-                float max_err = 0.0f;
+                    float max_err = 0.0f;
 
-                fprintf(stdout,
-                    "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
-                    variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
+                    fprintf(stdout,
+                        "Comparing output for '%s', with shared=%i, shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
+                        variant.name.c_str(), use_shared_prompt, shuffle, n_seq_max, n_ctx, n_ubatch);
 
-                // start filling the batch with prompts
-                while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
-                                   [](llama_pos p) { return p < n_seq_len; })) {
-                    for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
-                        if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) {
-                            continue;
-                        }
+                    if (use_shared_prompt) {
+                        // TODO: also test multiple distinct shared prompts in the same batch
+                        std::vector<llama_seq_id> seq_id_group;
+                        seq_id_group.reserve(n_seq_max);
 
-                        if (batch.n_tokens < n_batch) {
-                            const int64_t seq_len =
-                                std::min(n_batch - batch.n_tokens,
-                                         ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
+                        GGML_ASSERT(shared_prompt.size() < n_batch);
 
-                            ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
-                            seq_ids_in_batch.insert(seq_id);
-                            seq_id_n_past[seq_id] += seq_len;
-                        }
-                    }
-                    if (shuffle) {
-                        shuffle_batch(batch, rng);
+                        for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
+                            seq_id_group.push_back(seq_id);
+                            seq_id_n_past[seq_id] += shared_prompt.size();
+                        };
+
+                        for (size_t i = 0; i < shared_prompt.size(); ++i) {
+                            common_batch_add(batch, shared_prompt[i], i, seq_id_group, true);
+                        };
                     }
 
-                    llama_decode(ctx, batch);
+                    // start filling the batch with prompts
+                    while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
+                                       [](llama_pos p) { return p < n_seq_len; })) {
+                        for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
+                            if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) {
+                                continue;
+                            }
+
+                            if (batch.n_tokens < n_batch) {
+                                const int64_t seq_len =
+                                    std::min(n_batch - batch.n_tokens,
+                                             ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
+
+                                ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
+                                seq_ids_in_batch.insert(seq_id);
+                                seq_id_n_past[seq_id] += seq_len;
+                            }
+                        }
+                        if (shuffle) {
+                            shuffle_batch(batch, rng);
+                        }
 
-                for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
-                    float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
-                    if (!isfinite(err) || err > 1.0f / 1024.0f) {
-                        fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
-                        valid[seq_id] = false;
+                        llama_decode(ctx, batch);
+
+                        for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
+                            float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
+                            if (!isfinite(err) || err > 1.0f / 1024.0f) {
+                                fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
+                                valid[seq_id] = false;
+                            }
+                            max_err = std::max(err, max_err);
                         }
-                        max_err = std::max(err, max_err);
-                    }
 
-                    common_batch_clear(batch);
+                        common_batch_clear(batch);
 
-                    GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here
+                        GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here
 
-                    // cont batching
-                    for (llama_seq_id s : seq_ids_in_batch) {
-                        llama_pos & pos = seq_id_n_past[s];
-                        if (pos >= n_seq_len) {
-                            continue;
+                        // cont batching
+                        for (llama_seq_id s : seq_ids_in_batch) {
+                            llama_pos & pos = seq_id_n_past[s];
+                            if (pos >= n_seq_len) {
+                                continue;
+                            }
+                            ref_outputs[s].add_to_batch(batch, pos, 1, s);
+                            pos += 1;
                         }
-                        ref_outputs[s].add_to_batch(batch, pos, 1, s);
-                        pos += 1;
                     }
-                }
 
-                if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
-                    fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
-                } else {
-                    fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n",
-                            std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(),
-                            max_err);
-                    // cleanup and exit on first failure
-                    llama_free(ctx);
-                    llama_model_free(model);
-                    llama_batch_free(batch);
-                    exit(1);
-                }
+                    if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
+                        fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
+                    } else {
+                        fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n",
+                                std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(),
+                                max_err);
+                        // cleanup and exit on first failure
+                        llama_free(ctx);
+                        llama_model_free(model);
+                        llama_batch_free(batch);
+                        exit(1);
+                    }
 
-                // TODO: use seq_rm, seq_cp, etc. to test if they work properly
+                    // TODO: use seq_rm, seq_cp, etc. to test if they work properly
 
-                // TODO: test pooled embeddings
+                    // TODO: test pooled embeddings
 
-                llama_free(ctx);
+                    llama_free(ctx);
+                }
             }
         }
     }
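
For reference, the technique the new test path exercises is: add each token of the shared prefix to the batch once, tagged with every target sequence id, decode that batch a single time, and only afterwards extend each sequence with its own tokens. The sketch below shows that pattern in isolation; the helper name decode_shared_prefix, its parameters, and the batch sizing are illustrative assumptions and are not part of this change. It relies on the same llama.h API and common_batch_add / common_batch_clear helpers used above.

#include <vector>

#include "llama.h"
#include "common.h" // common_batch_add / common_batch_clear helpers

// Sketch: decode a shared prompt once so that n_seq sequences reuse the same
// KV cells for the prefix. Assumes `ctx` was created with n_seq_max >= n_seq.
static int decode_shared_prefix(llama_context * ctx,
                                const std::vector<llama_token> & shared_prompt,
                                int32_t n_seq) {
    // all target sequences receive the shared tokens
    std::vector<llama_seq_id> seq_id_group;
    for (llama_seq_id s = 0; s < n_seq; ++s) {
        seq_id_group.push_back(s);
    }

    // the batch must allow up to n_seq sequence ids per token
    llama_batch batch = llama_batch_init((int32_t) shared_prompt.size(), 0, n_seq);
    common_batch_clear(batch);

    // each shared token is added once, tagged with every sequence id
    for (size_t i = 0; i < shared_prompt.size(); ++i) {
        common_batch_add(batch, shared_prompt[i], (llama_pos) i, seq_id_group, true);
    }

    const int ret = llama_decode(ctx, batch);
    llama_batch_free(batch);
    return ret;
}

After such a call, each sequence can be continued independently (as the continuous-batching loop above already does), and the logits produced at the shared positions are expected to match across the sequences that share them, which is what the test validates against the reference outputs.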