@@ -1736,7 +1736,7 @@ struct sql_printer : public printer {
17361736 }
17371737};
17381738
1739- static void test_prompt (llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
1739+ static bool test_prompt (llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
17401740 llama_set_n_threads (ctx, n_threads, n_threads);
17411741
17421742 const llama_model * model = llama_get_model (ctx);
@@ -1753,14 +1753,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
17531753 for (int i = 1 ; i < n_tokens; i++) {
17541754 tokens[i] = std::rand () % n_vocab;
17551755 }
1756- llama_decode (ctx, llama_batch_get_one (tokens.data (), n_tokens));
1756+ int res = llama_decode (ctx, llama_batch_get_one (tokens.data (), n_tokens));
1757+ if (res != 0 ) {
1758+ fprintf (stderr, " %s: failed to decode prompt batch, res = %d\n " , __func__, res);
1759+ return false ;
1760+ }
17571761 n_processed += n_tokens;
17581762 }
17591763
17601764 llama_synchronize (ctx);
1765+ return true ;
17611766}
17621767
1763- static void test_gen (llama_context * ctx, int n_gen, int n_threads) {
1768+ static bool test_gen (llama_context * ctx, int n_gen, int n_threads) {
17641769 llama_set_n_threads (ctx, n_threads, n_threads);
17651770
17661771 const llama_model * model = llama_get_model (ctx);
@@ -1770,10 +1775,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
17701775 llama_token token = llama_vocab_get_add_bos (vocab) ? llama_vocab_bos (vocab) : std::rand () % n_vocab;
17711776
17721777 for (int i = 0 ; i < n_gen; i++) {
1773- llama_decode (ctx, llama_batch_get_one (&token, 1 ));
1778+ int res = llama_decode (ctx, llama_batch_get_one (&token, 1 ));
1779+ if (res != 0 ) {
1780+ fprintf (stderr, " %s: failed to decode generation batch, res = %d\n " , __func__, res);
1781+ return false ;
1782+ }
17741783 llama_synchronize (ctx);
17751784 token = std::rand () % n_vocab;
17761785 }
1786+ return true ;
17771787}
17781788
17791789static void llama_null_log_callback (enum ggml_log_level level, const char * text, void * user_data) {
@@ -1917,13 +1927,21 @@ int main(int argc, char ** argv) {
19171927 fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup prompt run\n " , params_idx, params_count);
19181928 }
19191929 // test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1920- test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1930+ bool res = test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1931+ if (!res) {
1932+ fprintf (stderr, " %s: error: failed to run prompt warmup\n " , __func__);
1933+ exit (1 );
1934+ }
19211935 }
19221936 if (t.n_gen > 0 ) {
19231937 if (params.progress ) {
19241938 fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup generation run\n " , params_idx, params_count);
19251939 }
1926- test_gen (ctx, 1 , t.n_threads );
1940+ bool res = test_gen (ctx, 1 , t.n_threads );
1941+ if (!res) {
1942+ fprintf (stderr, " %s: error: failed to run gen warmup\n " , __func__);
1943+ exit (1 );
1944+ }
19271945 }
19281946
19291947 for (int i = 0 ; i < params.reps ; i++) {
@@ -1934,7 +1952,11 @@ int main(int argc, char ** argv) {
19341952 fprintf (stderr, " llama-bench: benchmark %d/%zu: depth run %d/%d\n " , params_idx, params_count,
19351953 i + 1 , params.reps );
19361954 }
1937- test_prompt (ctx, t.n_depth , t.n_batch , t.n_threads );
1955+ bool res = test_prompt (ctx, t.n_depth , t.n_batch , t.n_threads );
1956+ if (!res) {
1957+ fprintf (stderr, " %s: error: failed to run depth\n " , __func__);
1958+ exit (1 );
1959+ }
19381960 }
19391961
19401962 uint64_t t_start = get_time_ns ();
@@ -1944,14 +1966,22 @@ int main(int argc, char ** argv) {
19441966 fprintf (stderr, " llama-bench: benchmark %d/%zu: prompt run %d/%d\n " , params_idx, params_count,
19451967 i + 1 , params.reps );
19461968 }
1947- test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1969+ bool res = test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1970+ if (!res) {
1971+ fprintf (stderr, " %s: error: failed to run prompt\n " , __func__);
1972+ exit (1 );
1973+ }
19481974 }
19491975 if (t.n_gen > 0 ) {
19501976 if (params.progress ) {
19511977 fprintf (stderr, " llama-bench: benchmark %d/%zu: generation run %d/%d\n " , params_idx, params_count,
19521978 i + 1 , params.reps );
19531979 }
1954- test_gen (ctx, t.n_gen , t.n_threads );
1980+ bool res = test_gen (ctx, t.n_gen , t.n_threads );
1981+ if (!res) {
1982+ fprintf (stderr, " %s: error: failed to run gen\n " , __func__);
1983+ exit (1 );
1984+ }
19551985 }
19561986
19571987 uint64_t t_ns = get_time_ns () - t_start;
0 commit comments