
Commit 9d873d7

test-model-random : shuffle across sequences but not within
There isn't really a use-case for fully-shuffled batches.

* test-model-random : use F32 as the KV cache type

Temporary until F16 is fixed on ARM when using FP16_VECTOR_ARITHMETIC.
1 parent 04b8f51 commit 9d873d7

File tree

1 file changed: +66 -7 lines changed


tests/test-model-random.cpp

Lines changed: 66 additions & 7 deletions
@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <cstdint>
 #include <cstdio>
+#include <queue>
 #include <random>
 #include <utility>
 // NOTE: the llm_arch enum is in the private API
@@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vector<int32_t> & ids) {
     memcpy(array, tmp.data(), ids.size() * elem_size);
 }
 
+static std::vector<int32_t> random_merge_ids(std::vector<std::queue<int32_t>> & ids_per_seq, std::mt19937 & rng) {
+    size_t total_size = 0;
+    for (const auto & v : ids_per_seq) {
+        total_size += v.size();
+    }
+
+    std::vector<int32_t> ids;
+    ids.reserve(total_size);
+
+    for (size_t i = 1; i <= total_size; ++i) {
+        // need weighted random selection
+        std::uniform_int_distribution<int32_t> rand(0, total_size - i);
+        int32_t rand_id = rand(rng);
+
+        // find out in which seq set this would belong
+        for (size_t j = 0; j < ids_per_seq.size(); ++j) {
+            if (rand_id < (int32_t) ids_per_seq[j].size()) {
+                ids.push_back(ids_per_seq[j].front());
+                ids_per_seq[j].pop();
+                break;
+            }
+            rand_id -= ids_per_seq[j].size();
+        }
+    }
+
+    return ids;
+}
+
+// shuffle across sequences but not within sequences
 static void shuffle_batch(struct llama_batch & batch, std::mt19937 & rng) {
-    std::vector<int32_t> ids(batch.n_tokens);
+    std::vector<std::set<llama_seq_id>> seq_sets;
+    std::vector<std::queue<int32_t>> ids_per_seq;
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
-        ids[i] = i;
+        int32_t seq_set_id = -1;
+        for (size_t s = 0; s < seq_sets.size(); ++s) {
+            for (int j = 0; j < batch.n_seq_id[i]; ++j) {
+                if (seq_sets[s].find(batch.seq_id[i][j]) != seq_sets[s].end()) {
+                    // any match, to avoid shuffling between dependent sets
+                    seq_set_id = s;
+                    break;
+                }
+            }
+        }
+
+        if (seq_set_id < 0) {
+            seq_sets.push_back({});
+            ids_per_seq.push_back({});
+            seq_set_id = seq_sets.size() - 1;
+        }
+        std::set<llama_seq_id> & seq_set = seq_sets[seq_set_id];
+        for (int j = 0; j < batch.n_seq_id[i]; ++j) {
+            // make sure the set contains all relevant seq_ids
+            seq_set.insert(batch.seq_id[i][j]);
+        }
+
+        ids_per_seq[seq_set_id].push(i);
     }
 
-    std::shuffle(ids.begin(), ids.end(), rng);
+    std::vector<int32_t> ids = random_merge_ids(ids_per_seq, rng);
+
+    GGML_ASSERT(ids.size() == (size_t) batch.n_tokens);
 
     if (batch.token) {
         permute_from_ids((uint8_t *) batch.token, sizeof(*batch.token), ids);
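The new `random_merge_ids` is a weighted random merge: on each step it draws uniformly over all tokens not yet emitted, so a queue is picked with probability proportional to how many tokens it still holds, and tokens taken from the same queue keep their original order. `shuffle_batch` first groups batch positions into those queues, placing any two tokens that share a `seq_id` in the same set so coupled sequences are never reordered relative to each other. Below is a minimal standalone sketch of the same merge idea (not part of the commit), using a hypothetical `merge_preserving_order` helper and toy data:

```cpp
// Standalone sketch of the weighted random merge used in random_merge_ids above.
// Each queue keeps its internal order; only the interleaving across queues is random.
#include <cstdint>
#include <cstdio>
#include <queue>
#include <random>
#include <vector>

static std::vector<int32_t> merge_preserving_order(std::vector<std::queue<int32_t>> queues, std::mt19937 & rng) {
    size_t total = 0;
    for (const auto & q : queues) {
        total += q.size();
    }

    std::vector<int32_t> out;
    out.reserve(total);

    while (out.size() < total) {
        // pick a source queue with probability proportional to its remaining size
        std::uniform_int_distribution<int32_t> dist(0, (int32_t) (total - out.size()) - 1);
        int32_t r = dist(rng);

        for (auto & q : queues) {
            if (r < (int32_t) q.size()) {
                out.push_back(q.front()); // front of a queue = earliest remaining token of that sequence
                q.pop();
                break;
            }
            r -= (int32_t) q.size();
        }
    }

    return out;
}

int main() {
    std::mt19937 rng(42);

    std::vector<std::queue<int32_t>> queues(2);
    for (int32_t i = 0; i < 4; ++i) { queues[0].push(i); } // "sequence 0": 0 1 2 3
    for (int32_t i = 4; i < 8; ++i) { queues[1].push(i); } // "sequence 1": 4 5 6 7

    // prints one random interleaving; 0..3 and 4..7 each keep their relative order
    for (int32_t id : merge_preserving_order(queues, rng)) {
        printf("%d ", id);
    }
    printf("\n");

    return 0;
}
```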
@@ -991,6 +1047,8 @@ int main(int argc, char ** argv) {
     ref_params.n_ubatch = 1;
     ref_params.n_ctx = n_seq_len;
     ref_params.n_seq_max = 1;
+    ref_params.type_k = GGML_TYPE_F32;
+    ref_params.type_v = GGML_TYPE_F32;
 
     llama_context * ref_ctx = llama_init_from_model(model, ref_params);
 
@@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) {
 
     for (bool shuffle : { false, true }) {
 
-        // skip shuffling the batch for non-recurrent models
-        // (simple splits don't handle shuffled batches correctly)
-        // FIXME: remove this
-        if (shuffle && !llama_model_is_recurrent(model)) {
+        // can't really shuffle a single sequence with itself
+        if (shuffle && n_seq_max == 1) {
            continue;
        }
 
@@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) {
        ctx_params.n_seq_max = n_seq_max;
        ctx_params.n_ubatch = n_ubatch;
        ctx_params.n_batch = n_batch;
+        // TODO: remove once F16 is fixed on ARM
+        ctx_params.type_k = GGML_TYPE_F32;
+        ctx_params.type_v = GGML_TYPE_F32;
 
        llama_context * ctx = llama_init_from_model(model, ctx_params);
 
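Both the reference context and the test context force an F32 KV cache through `type_k`/`type_v`, as a temporary workaround while F16 is broken on ARM with FP16_VECTOR_ARITHMETIC. A rough sketch of the same override, assuming the standard llama.cpp C API and a placeholder context size:

```cpp
// Sketch: create a llama.cpp context with an F32 KV cache instead of the F16 default.
// Assumes the usual llama.h API; n_ctx is a placeholder value.
#include "llama.h"

static llama_context * init_ctx_f32_kv(llama_model * model) {
    llama_context_params params = llama_context_default_params();
    params.n_ctx  = 512;           // placeholder context size
    params.type_k = GGML_TYPE_F32; // K cache type (default is F16)
    params.type_v = GGML_TYPE_F32; // V cache type (default is F16)
    return llama_init_from_model(model, params);
}
```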