77#include < algorithm>
88#include < cstdint>
99#include < cstdio>
10+ #include < queue>
1011#include < random>
1112#include < utility>
1213// NOTE: the llm_arch enum is in the private API
@@ -895,13 +896,68 @@ static void permute_from_ids(uint8_t * array, size_t elem_size, const std::vecto
895896 memcpy (array, tmp.data (), ids.size () * elem_size);
896897}
897898
899+ static std::vector<int32_t > random_merge_ids (std::vector<std::queue<int32_t >> & ids_per_seq, std::mt19937 & rng) {
900+ size_t total_size = 0 ;
901+ for (const auto & v : ids_per_seq) {
902+ total_size += v.size ();
903+ }
904+
905+ std::vector<int32_t > ids;
906+ ids.reserve (total_size);
907+
908+ for (size_t i = 1 ; i <= total_size; ++i) {
909+ // need weighted random selection
910+ std::uniform_int_distribution<int32_t > rand (0 , total_size - i);
911+ int32_t rand_id = rand (rng);
912+
913+ // find out in which seq set this would belong
914+ for (size_t j = 0 ; j < ids_per_seq.size (); ++j) {
915+ if (rand_id < (int32_t ) ids_per_seq[j].size ()) {
916+ ids.push_back (ids_per_seq[j].front ());
917+ ids_per_seq[j].pop ();
918+ break ;
919+ }
920+ rand_id -= ids_per_seq[j].size ();
921+ }
922+ }
923+
924+ return ids;
925+ }
926+
927+ // shuffle across sequences but not within seqences
898928static void shuffle_batch (struct llama_batch & batch, std::mt19937 & rng) {
899- std::vector<int32_t > ids (batch.n_tokens );
929+ std::vector<std::set<llama_seq_id>> seq_sets;
930+ std::vector<std::queue<int32_t >> ids_per_seq;
931+
900932 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
901- ids[i] = i;
933+ int32_t seq_set_id = -1 ;
934+ for (size_t s = 0 ; s < seq_sets.size (); ++s) {
935+ for (int j = 0 ; j < batch.n_seq_id [i]; ++j) {
936+ if (seq_sets[s].find (batch.seq_id [i][j]) != seq_sets[s].end ()) {
937+ // any match, to avoid shuffling between dependent sets
938+ seq_set_id = s;
939+ break ;
940+ }
941+ }
942+ }
943+
944+ if (seq_set_id < 0 ) {
945+ seq_sets.push_back ({});
946+ ids_per_seq.push_back ({});
947+ seq_set_id = seq_sets.size () - 1 ;
948+ }
949+ std::set<llama_seq_id> & seq_set = seq_sets[seq_set_id];
950+ for (int j = 0 ; j < batch.n_seq_id [i]; ++j) {
951+ // make sure the set contains all relevant seq_ids
952+ seq_set.insert (batch.seq_id [i][j]);
953+ }
954+
955+ ids_per_seq[seq_set_id].push (i);
902956 }
903957
904- std::shuffle (ids.begin (), ids.end (), rng);
958+ std::vector<int32_t > ids = random_merge_ids (ids_per_seq, rng);
959+
960+ GGML_ASSERT (ids.size () == (size_t ) batch.n_tokens );
905961
906962 if (batch.token ) {
907963 permute_from_ids ((uint8_t *) batch.token , sizeof (*batch.token ), ids);
@@ -991,6 +1047,8 @@ int main(int argc, char ** argv) {
9911047 ref_params.n_ubatch = 1 ;
9921048 ref_params.n_ctx = n_seq_len;
9931049 ref_params.n_seq_max = 1 ;
1050+ ref_params.type_k = GGML_TYPE_F32;
1051+ ref_params.type_v = GGML_TYPE_F32;
9941052
9951053 llama_context * ref_ctx = llama_init_from_model (model, ref_params);
9961054
@@ -1006,10 +1064,8 @@ int main(int argc, char ** argv) {
10061064
10071065 for (bool shuffle : { false , true }) {
10081066
1009- // skip shuffling the batch for non-recurrent models
1010- // (simple splits don't handle shuffled batches correctly)
1011- // FIXME: remove this
1012- if (shuffle && !llama_model_is_recurrent (model)) {
1067+ // can't really shuffle a single sequence with itself
1068+ if (shuffle && n_seq_max == 1 ) {
10131069 continue ;
10141070 }
10151071
@@ -1022,6 +1078,9 @@ int main(int argc, char ** argv) {
10221078 ctx_params.n_seq_max = n_seq_max;
10231079 ctx_params.n_ubatch = n_ubatch;
10241080 ctx_params.n_batch = n_batch;
1081+ // TODO: remove once F16 is fixed on ARM
1082+ ctx_params.type_k = GGML_TYPE_F32;
1083+ ctx_params.type_v = GGML_TYPE_F32;
10251084
10261085 llama_context * ctx = llama_init_from_model (model, ctx_params);
10271086
0 commit comments