
Commit a17c4f7

test-model-random : add shared prompt test variant
1 parent 4e58ca4 commit a17c4f7
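This commit adds a `use_shared_prompt` dimension to the random-model test: every sequence in the batch can now start from a common random prefix (`n_shared_len = 13` tokens) that is submitted once as multi-sequence batch entries, while the reference logits are built from the same prefix.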


tests/test-model-random.cpp

Lines changed: 123 additions & 78 deletions
@@ -832,9 +832,11 @@ struct reference_logits {
     std::vector<llama_token> inputs;
     std::vector<float> outputs;
 
-    reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng) {
+    reference_logits(llama_context * ctx, int32_t seq_len, std::mt19937 & rng,
+                     const std::vector<llama_token> & shared_prompt) {
         n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
         std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
+        GGML_ASSERT(shared_prompt.size() < (size_t) (seq_len / 4));
         std::uniform_int_distribution<int32_t> rand_prompt_len(seq_len / 4, 3 * seq_len / 4);
 
         llama_batch batch = llama_batch_init(seq_len, 0, 1);
@@ -843,7 +845,14 @@ struct reference_logits {
 
         prompt_len = rand_prompt_len(rng);
 
-        for (int32_t i = 0; i < prompt_len; ++i) {
+        for (int32_t i = 0; i < (int32_t) shared_prompt.size(); ++i) {
+            const llama_token token = shared_prompt[i];
+            inputs.push_back(token);
+
+            common_batch_add(batch, token, i, { 0 }, true);
+        }
+
+        for (int32_t i = shared_prompt.size(); i < prompt_len; ++i) {
             const llama_token token = rand_token(rng);
             inputs.push_back(token);
 
@@ -1065,6 +1074,7 @@ int main(int argc, char ** argv) {
 
     // TODO: multiple sequences per token
     const int32_t n_batch = 509; // prime number
+    const int32_t n_shared_len = 13; // prime number, shared prompt length
     const int32_t n_seq_len = 127; // prime number
 
     llama_batch batch = llama_batch_init(n_batch, 0, 1);
@@ -1092,9 +1102,21 @@ int main(int argc, char ** argv) {
 
     GGML_ASSERT(model);
 
-    // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
+    const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
     // const auto n_embd = llama_model_n_embd(model);
 
+    std::vector<llama_token> shared_prompt;
+    // populate shared prompt
+    {
+        std::uniform_int_distribution<llama_token> rand_token(0, n_vocab - 1);
+        shared_prompt.reserve(n_shared_len);
+
+        for (int32_t i = 0; i < n_shared_len; ++i) {
+            shared_prompt.push_back(rand_token(rng));
+        }
+    }
+
+    // TODO: avoid re-creating reference outputs
     for (int32_t n_seq_max : { 1, 2, 5 }) {
 
         // TODO(later): context shift testing
@@ -1119,111 +1141,134 @@ int main(int argc, char ** argv) {
 
             for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
                 llama_memory_clear(mem, true);
-                ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng));
+                ref_outputs.push_back(reference_logits(ref_ctx, n_seq_len, rng, shared_prompt));
             }
 
             llama_free(ref_ctx);
         }
 
-        for (bool shuffle : { false, true }) {
+        for (bool use_shared_prompt : { false, true }) {
+            for (bool shuffle : { false, true }) {
 
-            // can't really shuffle a single sequence with itself
-            if (shuffle && n_seq_max == 1) {
-                continue;
-            }
+                // can't really shuffle a single sequence with itself
+                if (shuffle && n_seq_max == 1) {
+                    continue;
+                }
+                // can't really share a prompt with only one sequence
+                if (use_shared_prompt && n_seq_max == 1) {
+                    continue;
+                }
 
-            for (int32_t n_ubatch : { 1, 2, 512 } ) {
+                for (int32_t n_ubatch : { 1, 2, 512 } ) {
 
-                std::vector<bool> valid(n_seq_max, true);
+                    std::vector<bool> valid(n_seq_max, true);
 
-                llama_context_params ctx_params = llama_context_default_params();
-                ctx_params.n_ctx = n_ctx;
-                ctx_params.n_seq_max = n_seq_max;
-                ctx_params.n_ubatch = n_ubatch;
-                ctx_params.n_batch = n_batch;
-                // TODO: remove once F16 is fixed on ARM
-                ctx_params.type_k = GGML_TYPE_F32;
-                ctx_params.type_v = GGML_TYPE_F32;
+                    llama_context_params ctx_params = llama_context_default_params();
+                    ctx_params.n_ctx = n_ctx;
+                    ctx_params.n_seq_max = n_seq_max;
+                    ctx_params.n_ubatch = n_ubatch;
+                    ctx_params.n_batch = n_batch;
+                    // TODO: remove once F16 is fixed on ARM
+                    ctx_params.type_k = GGML_TYPE_F32;
+                    ctx_params.type_v = GGML_TYPE_F32;
 
-                llama_context * ctx = llama_init_from_model(model, ctx_params);
+                    llama_context * ctx = llama_init_from_model(model, ctx_params);
 
-                common_batch_clear(batch);
+                    common_batch_clear(batch);
 
-                std::set<llama_seq_id> seq_ids_in_batch;
-                std::vector<llama_pos> seq_id_n_past(n_seq_max, 0);
+                    std::set<llama_seq_id> seq_ids_in_batch;
+                    std::vector<llama_pos> seq_id_n_past(n_seq_max, 0);
 
-                float max_err = 0.0f;
+                    float max_err = 0.0f;
 
-                fprintf(stdout,
-                        "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
-                        variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
+                    fprintf(stdout,
+                            "Comparing output for '%s', with shared=%i, shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
+                            variant.name.c_str(), use_shared_prompt, shuffle, n_seq_max, n_ctx, n_ubatch);
 
-                // start filling the batch with prompts
-                while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
-                                   [](llama_pos p) { return p < n_seq_len; })) {
-                    for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
-                        if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) {
-                            continue;
-                        }
+                    if (use_shared_prompt) {
+                        // TODO: also test multiple distinct shared prompts in the same batch
+                        std::vector<llama_seq_id> seq_id_group;
+                        seq_id_group.reserve(n_seq_max);
 
-                        if (batch.n_tokens < n_batch) {
-                            const int64_t seq_len =
-                                std::min(n_batch - batch.n_tokens,
-                                         ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
+                        GGML_ASSERT(shared_prompt.size() < n_batch);
 
-                            ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
-                            seq_ids_in_batch.insert(seq_id);
-                            seq_id_n_past[seq_id] += seq_len;
-                        }
-                    }
-                    if (shuffle) {
-                        shuffle_batch(batch, rng);
+                        for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
+                            seq_id_group.push_back(seq_id);
+                            seq_id_n_past[seq_id] += shared_prompt.size();
+                        };
+
+                        for (size_t i = 0; i < shared_prompt.size(); ++i) {
+                            common_batch_add(batch, shared_prompt[i], i, seq_id_group, true);
+                        };
                     }
 
-                    llama_decode(ctx, batch);
+                    // start filling the batch with prompts
+                    while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
+                                       [](llama_pos p) { return p < n_seq_len; })) {
+                        for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
+                            if (seq_id_n_past[seq_id] >= ref_outputs[seq_id].prompt_len) {
+                                continue;
+                            }
+
+                            if (batch.n_tokens < n_batch) {
+                                const int64_t seq_len =
+                                    std::min(n_batch - batch.n_tokens,
+                                             ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
+
+                                ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
+                                seq_ids_in_batch.insert(seq_id);
+                                seq_id_n_past[seq_id] += seq_len;
+                            }
+                        }
+                        if (shuffle) {
+                            shuffle_batch(batch, rng);
+                        }
 
-                    for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
-                        float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
-                        if (!isfinite(err) || err > 1.0f / 1024.0f) {
-                            fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
-                            valid[seq_id] = false;
+                        llama_decode(ctx, batch);
+
+                        for (llama_seq_id seq_id = 0; seq_id < n_seq_max; ++seq_id) {
+                            float err = ref_outputs[seq_id].validate_batch(ctx, batch, seq_id);
+                            if (!isfinite(err) || err > 1.0f / 1024.0f) {
+                                fprintf(stderr, "Error for seq_id %i is %f at n_past=%i\n", seq_id, err, seq_id_n_past[seq_id]);
+                                valid[seq_id] = false;
+                            }
+                            max_err = std::max(err, max_err);
                         }
-                        max_err = std::max(err, max_err);
-                    }
 
-                    common_batch_clear(batch);
+                        common_batch_clear(batch);
 
-                    GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here
+                        GGML_ASSERT(n_seq_max <= n_batch); // not handling splitting this across batches here
 
-                    // cont batching
-                    for (llama_seq_id s : seq_ids_in_batch) {
-                        llama_pos & pos = seq_id_n_past[s];
-                        if (pos >= n_seq_len) {
-                            continue;
+                        // cont batching
+                        for (llama_seq_id s : seq_ids_in_batch) {
+                            llama_pos & pos = seq_id_n_past[s];
+                            if (pos >= n_seq_len) {
+                                continue;
+                            }
+                            ref_outputs[s].add_to_batch(batch, pos, 1, s);
+                            pos += 1;
                         }
-                        ref_outputs[s].add_to_batch(batch, pos, 1, s);
-                        pos += 1;
                     }
-                }
 
-                if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
-                    fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
-                } else {
-                    fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n",
-                            std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(),
-                            max_err);
-                    // cleanup and exit on first failure
-                    llama_free(ctx);
-                    llama_model_free(model);
-                    llama_batch_free(batch);
-                    exit(1);
-                }
+                    if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
+                        fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
+                    } else {
+                        fprintf(stdout, "(%zu%%) \033[1;31mFAILED\033[0m (max err: %.4g)\n",
+                                std::count_if(valid.begin(), valid.end(), [](bool v) { return v == false; }) * 100 / valid.size(),
+                                max_err);
+                        // cleanup and exit on first failure
+                        llama_free(ctx);
+                        llama_model_free(model);
+                        llama_batch_free(batch);
+                        exit(1);
+                    }
 
-                // TODO: use seq_rm, seq_cp, etc. to test if they work properly
+                    // TODO: use seq_rm, seq_cp, etc. to test if they work properly
 
-                // TODO: test pooled embeddings
+                    // TODO: test pooled embeddings
 
-                llama_free(ctx);
+                    llama_free(ctx);
+                }
             }
         }
     }
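
For context, the mechanism this variant exercises is that a llama_batch entry can carry several sequence ids at once, so a prefix shared by all sequences is decoded a single time and its KV-cache entries are reused by every sequence. Below is a minimal sketch of that pattern using the same helpers as the diff (common_batch_add, common_batch_clear, llama_batch_init); it assumes ctx, n_seq, and shared_prompt are already set up, and is an illustration rather than part of the commit:

    // decode one shared prompt across n_seq sequences in a single batch
    std::vector<llama_seq_id> seq_id_group;
    for (llama_seq_id s = 0; s < n_seq; ++s) {
        seq_id_group.push_back(s);
    }

    llama_batch batch = llama_batch_init((int32_t) shared_prompt.size(), 0, n_seq);
    common_batch_clear(batch);

    // each shared token is added once, tagged with every sequence id;
    // positions start at 0 since this prefix begins every sequence
    for (size_t i = 0; i < shared_prompt.size(); ++i) {
        common_batch_add(batch, shared_prompt[i], i, seq_id_group, true);
    }

    llama_decode(ctx, batch); // the prefix is evaluated once for all sequences
    llama_batch_free(batch);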
