@@ -57,6 +57,13 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int32_t n_vocab = llama_vocab_n_tokens(vocab);
+
+    const auto get_token_rand = [n_vocab]() -> llama_token {
+        return std::rand() % n_vocab;
+    };
+
     auto * mem = llama_get_memory(ctx);
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
@@ -93,7 +100,7 @@ int main(int argc, char ** argv) {
     // warm up
     {
         for (int i = 0; i < 16; ++i) {
-            common_batch_add(batch, 0, i, { 0 }, false);
+            common_batch_add(batch, get_token_rand(), i, { 0 }, false);
         }
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -127,7 +134,7 @@ int main(int argc, char ** argv) {
 
             for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
                 for (int i = 0; i < pp; ++i) {
-                    common_batch_add(batch, 0, i, { j }, i == pp - 1);
+                    common_batch_add(batch, get_token_rand(), i, { j }, i == pp - 1);
                 }
             }
 
@@ -154,7 +161,7 @@ int main(int argc, char ** argv) {
                 common_batch_clear(batch);
 
                 for (int j = 0; j < pl; ++j) {
-                    common_batch_add(batch, 0, pp + i, { j }, true);
+                    common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
                 }
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
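The change replaces the hard-coded token `0` with a randomly drawn token id at every `common_batch_add` call site, presumably so the benchmark decodes varied rather than identical tokens. The pattern is small enough to demonstrate standalone; below is a minimal sketch of the same lambda, assuming a placeholder vocabulary size of 32000 in place of the real `llama_vocab_n_tokens(llama_model_get_vocab(model))` query, with `llama_token` typedef'd locally since `llama.h` is not included:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>

using llama_token = int32_t; // llama.cpp token ids are 32-bit ints

int main() {
    // seed once for reproducibility control; the patch itself relies on
    // the process-default std::rand() seed
    std::srand((unsigned) std::time(nullptr));

    const int32_t n_vocab = 32000; // placeholder; the patch queries the model's vocab

    // same shape as the lambda in the patch: capture the vocab size by value
    // and return a pseudo-random token id in [0, n_vocab)
    const auto get_token_rand = [n_vocab]() -> llama_token {
        return std::rand() % n_vocab;
    };

    for (int i = 0; i < 4; ++i) {
        std::printf("token %d: %d\n", i, get_token_rand());
    }
    return 0;
}
```

Note that `std::rand() % n_vocab` has a slight modulo bias toward smaller token ids, which is inconsequential for a throughput benchmark but would matter if uniform sampling were required.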