@@ -1089,6 +1089,8 @@ int main(int argc, char ** argv) {
10891089 std::set<llama_seq_id> seq_ids_in_batch;
10901090 std::vector<llama_pos> seq_id_n_past (n_seq_max, 0 );
10911091
1092+ float max_err = 0 .0f ;
1093+
10921094 // start filling the batch with prompts
10931095 while (std::any_of (seq_id_n_past.begin (), seq_id_n_past.end (),
10941096 [](llama_pos p) { return p < n_seq_len; })) {
@@ -1119,6 +1121,7 @@ int main(int argc, char ** argv) {
11191121 fprintf (stderr, " Error for seq_id %i is %f at n_past=%i\n " , seq_id, err, seq_id_n_past[seq_id]);
11201122 valid[seq_id] = false ;
11211123 }
1124+ max_err = std::max (err, max_err);
11221125 }
11231126
11241127 common_batch_clear (batch);
@@ -1140,10 +1143,11 @@ int main(int argc, char ** argv) {
11401143 " Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: " ,
11411144 variant.name .c_str (), shuffle, n_seq_max, n_ctx, n_ubatch);
11421145 if (std::all_of (valid.begin (), valid.end (), [](bool v) { return v; })) {
1143- fprintf (stdout, " \033 [1;32mOK\033 [0m\n " );
1146+ fprintf (stdout, " \033 [1;32mOK\033 [0m (max err: %.2g) \n " , max_err );
11441147 } else {
1145- fprintf (stdout, " (%zu%%) \033 [1;31mFAILED\033 [0m\n " ,
1146- std::count_if (valid.begin (), valid.end (), [](bool v) { return v == false ; }) * 100 / valid.size ());
1148+ fprintf (stdout, " (%zu%%) \033 [1;31mFAILED\033 [0m (max err: %.4g)\n " ,
1149+ std::count_if (valid.begin (), valid.end (), [](bool v) { return v == false ; }) * 100 / valid.size (),
1150+ max_err);
11471151 // cleanup and exit on first failure
11481152 llama_free (ctx);
11491153 llama_model_free (model);
0 commit comments