@@ -21,6 +21,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    params.n_batch = params.n_ctx;
+
     common_init();
 
     int is_pp_shared = params.is_pp_shared;
@@ -61,48 +63,21 @@ int main(int argc, char ** argv) {
 
     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
-    // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            llama_synchronize(ctx);
-        }
-
-        return true;
-    };
-
     // warm up
     {
         for (int i = 0; i < 16; ++i) {
             common_batch_add(batch, 0, i, { 0 }, false);
         }
 
-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+        if (const auto ret = llama_decode(ctx, batch)) {
+            LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
             return 1;
         }
     }
 
     if (!params.batched_bench_output_jsonl) {
         LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
         LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
         LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -134,9 +109,11 @@ int main(int argc, char ** argv) {
 
                 llama_kv_self_clear(ctx);
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                    LOG_ERR("%s: llama_decode() failed\n", __func__);
-                    return 1;
+                if (batch.n_tokens > 0) {
+                    if (const auto ret = llama_decode(ctx, batch)) {
+                        LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
+                        return 1;
+                    }
                 }
 
                 if (is_pp_shared) {
@@ -156,8 +133,8 @@ int main(int argc, char ** argv) {
                     common_batch_add(batch, 0, pp + i, { j }, true);
                 }
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                    LOG_ERR("%s: llama_decode() failed\n", __func__);
+                if (const auto ret = llama_decode(ctx, batch)) {
+                    LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
                     return 1;
                 }
             }
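
For context, the removed `decode_helper` chunked work because a single `llama_decode()` call is limited to the context's `n_batch` tokens, so larger batches had to be fed as `n_batch`-sized views. With `params.n_batch` raised to `params.n_ctx` above, every batch this benchmark builds fits in one call. Below is a minimal sketch of that direct-decode path, assuming only the llama.cpp C API (`llama_decode`, `llama_batch`); the helper name `decode_all` is purely illustrative and not part of the change.

```cpp
#include <cstdio>

#include "llama.h"

// Sketch only: submit an entire batch with a single llama_decode() call.
// Assumes the context was created with n_batch >= batch.n_tokens, which is
// what setting params.n_batch = params.n_ctx guarantees in this benchmark.
static bool decode_all(llama_context * ctx, const llama_batch & batch) {
    if (batch.n_tokens == 0) {
        return true; // nothing to decode
    }

    const int ret = llama_decode(ctx, batch);
    if (ret != 0) {
        fprintf(stderr, "llama_decode() failed, ret = %d\n", ret);
        return false;
    }

    return true;
}
```

Internally, `llama_decode()` still splits the work into `n_ubatch`-sized micro-batches, so caller-side chunking (and the explicit `llama_synchronize()` between chunks) is not required once the whole batch fits within `n_batch`.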