@@ -57,6 +57,13 @@ int main(int argc, char ** argv) {
57
57
return 1 ;
58
58
}
59
59
60
+ const llama_vocab * vocab = llama_model_get_vocab (model);
61
+ const int32_t n_vocab = llama_vocab_n_tokens (vocab);
62
+
63
+ const auto get_token_rand = [n_vocab]() -> llama_token {
64
+ return std::rand () % n_vocab;
65
+ };
66
+
60
67
auto * mem = llama_get_memory (ctx);
61
68
62
69
const int32_t n_kv_max = llama_n_ctx (ctx);
@@ -93,7 +100,7 @@ int main(int argc, char ** argv) {
93
100
// warm up
94
101
{
95
102
for (int i = 0 ; i < 16 ; ++i) {
96
- common_batch_add (batch, 0 , i, { 0 }, false );
103
+ common_batch_add (batch, get_token_rand () , i, { 0 }, false );
97
104
}
98
105
99
106
if (!decode_helper (ctx, batch, ctx_params.n_batch )) {
@@ -127,7 +134,7 @@ int main(int argc, char ** argv) {
127
134
128
135
for (int j = 0 ; j < (is_pp_shared ? 1 : pl); ++j) {
129
136
for (int i = 0 ; i < pp; ++i) {
130
- common_batch_add (batch, 0 , i, { j }, i == pp - 1 );
137
+ common_batch_add (batch, get_token_rand () , i, { j }, i == pp - 1 );
131
138
}
132
139
}
133
140
@@ -154,7 +161,7 @@ int main(int argc, char ** argv) {
154
161
common_batch_clear (batch);
155
162
156
163
for (int j = 0 ; j < pl; ++j) {
157
- common_batch_add (batch, 0 , pp + i, { j }, true );
164
+ common_batch_add (batch, get_token_rand () , pp + i, { j }, true );
158
165
}
159
166
160
167
if (!decode_helper (ctx, batch, ctx_params.n_batch )) {
0 commit comments