@@ -23,7 +23,8 @@ struct LLMAttrType
2323
2424 // std::string template_prefill_filename_axmodel = "minicpmv/prefill_axmodel/minicpm_p96_l%d.axmodel";
2525 // int prefill_axmodel_num = 40;
26- int prefill_token_num = 96 ; // auto calc
26+ int prefill_token_num = 128 ; // auto calc
27+ int prefill_max_token_num = 512 ;
2728
2829 std::string filename_post_axmodel = " tinyllama-int8/tinyllama_post.axmodel" ;
2930
@@ -76,7 +77,7 @@ class LLM
7677 std::vector<LLMLayer> llama_layers;
7778 ax_runner_ax650 llama_post;
7879
79- int prefill_grpid = 1 ;
80+ // int prefill_grpid = 1;
8081 int decode_grpid = 0 ;
8182
8283 // ax_runner_ax650 vpm_resampler;
@@ -235,7 +236,7 @@ class LLM
235236 ALOGE (" init axmodel(%s) failed" , layer.filename .c_str ());
236237 }
237238 }
238-
239+ printf ( " \n " );
239240 {
240241 _attr.max_token_len = llama_layers[0 ].layer .get_input (" mask" ).nSize / sizeof (unsigned short ) - 1 ;
241242 ALOGI (" max_token_len : %d" , _attr.max_token_len );
@@ -250,8 +251,10 @@ class LLM
250251 return false ;
251252 }
252253
253- _attr.prefill_token_num = llama_layers[0 ].layer .get_input (prefill_grpid , " indices" ).vShape [1 ];
254+ _attr.prefill_token_num = llama_layers[0 ].layer .get_input (1 , " indices" ).vShape [1 ];
254255 ALOGI (" prefill_token_num : %d" , _attr.prefill_token_num );
256+ _attr.prefill_max_token_num = llama_layers[0 ].layer .get_input (llama_layers[0 ].layer .get_num_input_groups () - 1 , " mask" ).vShape [2 ];
257+ ALOGI (" prefill_max_token_num : %d" , _attr.prefill_max_token_num );
255258 }
256259 if (attr.b_dynamic_load_axmodel_layer )
257260 {
@@ -298,9 +301,9 @@ class LLM
298301 int Encode (std::vector<unsigned short > &out_embed, std::string prompt = " What is in the image?" )
299302 {
300303 std::vector<int > input_ids = tokenizer->Encode (prompt, true );
301- if (input_ids.size () > _attr.prefill_token_num )
304+ if (input_ids.size () > _attr.prefill_max_token_num )
302305 {
303- ALOGE (" input_ids(%d) > prefill_token_num (%d)" , input_ids.size (), _attr.prefill_token_num );
306+ ALOGE (" input_ids(%d) > prefill_max_token_num (%d)" , input_ids.size (), _attr.prefill_max_token_num );
304307 return -1 ;
305308 }
306309 out_embed.resize (input_ids.size () * _attr.tokens_embed_size );
@@ -327,23 +330,50 @@ class LLM
327330 b_stop = false ;
328331 std::string final_out;
329332
333+ int input_embed_num = test_embed.size () / _attr.tokens_embed_size ;
334+ ALOGI (" input token num : %d" , input_embed_num);
335+ int prefill_split_num = ceil ((double )input_embed_num / _attr.prefill_token_num );
336+ // ALOGI("prefill_split_num : %d", prefill_split_num);
337+
330338 bfloat16 bf16 = -65536 .f ;
331339 std::vector<unsigned short > mask (_attr.kv_cache_num + 1 , bf16 .data );
332- std::vector<unsigned short > mask_p (_attr.prefill_token_num * _attr.prefill_token_num , bf16 .data );
340+ std::vector<std::vector<unsigned short >> mask_p (prefill_split_num);
341+ std::vector<unsigned short > embed (_attr.tokens_embed_size , 0 );
333342
334- for (size_t i = 0 ; i < _attr.prefill_token_num ; i++)
343+ // for (size_t i = 0; i < _attr.prefill_token_num; i++)
344+ // {
345+ // for (size_t j = 0; j < i + 1; j++)
346+ // {
347+ // mask_p[i * _attr.prefill_token_num + j] = 0;
348+ // }
349+ // }
350+ for (size_t p = 0 ; p < prefill_split_num; p++)
335351 {
336- for (size_t j = 0 ; j < i + 1 ; j++)
352+ std::vector<unsigned short > &mask_tmp = mask_p[p];
353+ mask_tmp.resize ((p + 1 ) * _attr.prefill_token_num * _attr.prefill_token_num , bf16 .data );
354+
355+ size_t i = 0 ;
356+ for (size_t t = p * _attr.prefill_token_num ; t < (p + 1 ) * _attr.prefill_token_num ; t++)
337357 {
338- mask_p[i * _attr.prefill_token_num + j] = 0 ;
358+ if (t < input_embed_num)
359+ {
360+ for (size_t j = 0 ; j < p * _attr.prefill_token_num + i + 1 ; j++)
361+ mask_tmp[i * ((p + 1 ) * _attr.prefill_token_num ) + j] = 0 ;
362+ }
363+ i++;
339364 }
365+ // char path[128];
366+ // sprintf(path, "mask_p_%d.bin", p);
367+ // FILE *fp = fopen(path, "wb");
368+ // fwrite(mask_tmp.data(), sizeof(unsigned short), mask_tmp.size(), fp);
369+ // fclose(fp);
340370 }
341371
342372 std::vector<int > cached_token;
343373 std::vector<int > token_ids;
344374 // std::vector<int> token_ids = tokenizer->Encode(input_str);
345375 // int len_of_input = token_ids.size();
346- int input_embed_num = test_embed. size () / _attr. tokens_embed_size ;
376+
347377 // ALOGI("input_embed_num(%d)", input_embed_num);
348378
349379 mask[_attr.kv_cache_num ] = 0 ;
@@ -355,70 +385,111 @@ class LLM
355385 timer ttft_timer;
356386 ttft_timer.start ();
357387
358- for (unsigned int m = 0 ; m < _attr. axmodel_num ; m ++)
388+ for (size_t p = 0 ; p < prefill_split_num; p ++)
359389 {
360390 if (b_stop)
361391 {
362392 break ;
363393 }
364394
365- auto &layer = llama_layers[m];
366- auto &layer_llama = llama_layers[m];
395+ std::vector<unsigned short > &mask_tmp = mask_p[p];
396+ std::vector<unsigned short > embed_tmp (_attr.prefill_token_num * _attr.tokens_embed_size , 0 );
397+ if (p == (prefill_split_num - 1 ))
398+ {
399+ memcpy (embed_tmp.data (), test_embed.data () + p * _attr.prefill_token_num * _attr.tokens_embed_size , (input_embed_num - p * _attr.prefill_token_num ) * _attr.tokens_embed_size * sizeof (unsigned short ));
400+ }
401+ else
402+ {
403+ memcpy (embed_tmp.data (), test_embed.data () + p * _attr.prefill_token_num * _attr.tokens_embed_size , _attr.prefill_token_num * _attr.tokens_embed_size * sizeof (unsigned short ));
404+ }
405+ int prefill_grpid = p + 1 ;
367406
368- if ( _attr.b_dynamic_load_axmodel_layer )
407+ for ( unsigned int m = 0 ; m < _attr.axmodel_num ; m++ )
369408 {
370- int ret;
371- if (_attr.b_use_mmap_load_layer )
409+ if (b_stop)
372410 {
373- ret = layer. layer . init (( char *)layer. layer_buffer . data (), layer. layer_buffer . size ()) ;
411+ break ;
374412 }
375- else
413+
414+ auto &layer = llama_layers[m];
415+ auto &layer_llama = llama_layers[m];
416+
417+ if (_attr.b_dynamic_load_axmodel_layer )
376418 {
377- ret = layer.layer .init (layer.layer_buffer_vec .data (), layer.layer_buffer_vec .size ());
419+ int ret;
420+ if (_attr.b_use_mmap_load_layer )
421+ {
422+ ret = layer.layer .init ((char *)layer.layer_buffer .data (), layer.layer_buffer .size ());
423+ }
424+ else
425+ {
426+ ret = layer.layer .init (layer.layer_buffer_vec .data (), layer.layer_buffer_vec .size ());
427+ }
428+ if (ret != 0 )
429+ {
430+ ALOGE (" init axmodel(%s) failed" , layer.filename .c_str ());
431+ }
378432 }
379- if (ret != 0 )
433+
434+ auto &input_indices = layer.layer .get_input (prefill_grpid, " indices" );
435+ unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr ;
436+
437+ for (unsigned int i = 0 ; i < _attr.prefill_token_num ; i++)
380438 {
381- ALOGE ( " init axmodel(%s) failed " , layer. filename . c_str ()) ;
439+ input_indices_ptr[i] = p * _attr. prefill_token_num + i ;
382440 }
383- }
384441
385- auto &input_indices = layer.layer .get_input (prefill_grpid, " indices" );
386- unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr ;
387- for (unsigned int i = 0 ; i < input_embed_num; i++)
388- {
389- input_indices_ptr[i] = i;
390- }
442+ if (p > 0 )
443+ {
444+ auto &input_prefill_k_cache = layer.layer .get_input (prefill_grpid, " K_cache" );
445+ auto &input_prefill_v_cache = layer.layer .get_input (prefill_grpid, " V_cache" );
446+ for (size_t i = 0 ; i < p; i++)
447+ {
448+ auto &output_k_cache = layer.layer .get_output (i + 1 , " K_cache_out" );
449+ memcpy ((unsigned short *)input_prefill_k_cache.pVirAddr + i * _attr.prefill_token_num * _attr.kv_cache_size ,
450+ output_k_cache.pVirAddr ,
451+ sizeof (unsigned short ) * _attr.prefill_token_num * _attr.kv_cache_size );
452+
453+ auto &output_v_cache = layer.layer .get_output (i + 1 , " V_cache_out" );
454+ memcpy ((unsigned short *)input_prefill_v_cache.pVirAddr + i * _attr.prefill_token_num * _attr.kv_cache_size ,
455+ output_v_cache.pVirAddr ,
456+ sizeof (unsigned short ) * _attr.prefill_token_num * _attr.kv_cache_size );
457+ }
458+ }
391459
392- auto &input_mask = layer.layer .get_input (prefill_grpid, " mask" );
393- memcpy (input_mask.pVirAddr , mask_p .data (), mask_p .size () * sizeof (unsigned short ));
460+ auto &input_mask = layer.layer .get_input (prefill_grpid, " mask" );
461+ memcpy (input_mask.pVirAddr , mask_tmp .data (), mask_tmp .size () * sizeof (unsigned short ));
394462
395- auto &input_input = layer.layer .get_input (prefill_grpid, " input" );
396- memcpy (input_input.pVirAddr , test_embed.data (), test_embed.size () * sizeof (unsigned short ));
397- if (m == 0 )
398- {
399- test_embed.resize (_attr.prefill_token_num * _attr.tokens_embed_size );
400- }
463+ auto &input_input = layer.layer .get_input (prefill_grpid, " input" );
464+ memcpy (input_input.pVirAddr , embed_tmp.data (), embed_tmp.size () * sizeof (unsigned short ));
401465
402- layer.layer .inference (prefill_grpid);
466+ layer.layer .inference (prefill_grpid);
403467
404- auto &output_k_cache = layer.layer .get_output (prefill_grpid, " K_cache_out" );
405- AX_SYS_MinvalidateCache (output_k_cache.phyAddr , output_k_cache.pVirAddr , output_k_cache.nSize );
406- auto &input_k_cache = layer_llama.layer .get_input (decode_grpid, " K_cache" );
407- memcpy (input_k_cache.pVirAddr , output_k_cache.pVirAddr , sizeof (unsigned short ) * _attr.prefill_token_num * _attr.kv_cache_size );
468+ auto &output_k_cache = layer.layer .get_output (prefill_grpid, " K_cache_out" );
469+ AX_SYS_MinvalidateCache (output_k_cache.phyAddr , output_k_cache.pVirAddr , output_k_cache.nSize );
470+ auto &input_k_cache = layer_llama.layer .get_input (decode_grpid, " K_cache" );
471+ memcpy (( unsigned short *) input_k_cache.pVirAddr + p * _attr. prefill_token_num * _attr. kv_cache_size , output_k_cache.pVirAddr , sizeof (unsigned short ) * _attr.prefill_token_num * _attr.kv_cache_size );
408472
409- auto &output_v_cache = layer.layer .get_output (prefill_grpid, " V_cache_out" );
410- AX_SYS_MinvalidateCache (output_v_cache.phyAddr , output_v_cache.pVirAddr , output_v_cache.nSize );
411- auto &input_v_cache = layer_llama.layer .get_input (decode_grpid, " V_cache" );
412- memcpy (input_v_cache.pVirAddr , output_v_cache.pVirAddr , sizeof (unsigned short ) * _attr.prefill_token_num * _attr.kv_cache_size );
473+ auto &output_v_cache = layer.layer .get_output (prefill_grpid, " V_cache_out" );
474+ AX_SYS_MinvalidateCache (output_v_cache.phyAddr , output_v_cache.pVirAddr , output_v_cache.nSize );
475+ auto &input_v_cache = layer_llama.layer .get_input (decode_grpid, " V_cache" );
476+ memcpy (( unsigned short *) input_v_cache.pVirAddr + p * _attr. prefill_token_num * _attr. kv_cache_size , output_v_cache.pVirAddr , sizeof (unsigned short ) * _attr.prefill_token_num * _attr.kv_cache_size );
413477
414- auto &output = layer.layer .get_output (prefill_grpid, " output" );
415- AX_SYS_MinvalidateCache (output.phyAddr , output.pVirAddr , output.nSize );
416- memcpy (test_embed.data (), output.pVirAddr , test_embed.size () * sizeof (unsigned short ));
417- if (_attr.b_dynamic_load_axmodel_layer )
478+ auto &output = layer.layer .get_output (prefill_grpid, " output" );
479+ AX_SYS_MinvalidateCache (output.phyAddr , output.pVirAddr , output.nSize );
480+ memcpy (embed_tmp.data (), output.pVirAddr , embed_tmp.size () * sizeof (unsigned short ));
481+ if (_attr.b_dynamic_load_axmodel_layer )
482+ {
483+ layer.layer .deinit ();
484+ }
485+ // ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
486+ }
487+ if (p == (prefill_split_num - 1 ))
418488 {
419- layer.layer .deinit ();
489+ memcpy (embed.data (),
490+ embed_tmp.data () + (input_embed_num - p * _attr.prefill_token_num - 1 ) * _attr.tokens_embed_size ,
491+ _attr.tokens_embed_size * sizeof (unsigned short ));
420492 }
421- // ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
422493 }
423494
424495 // ALOGI("prefill time cost: %.2f s", t_cost.cost() / 1000);
@@ -433,11 +504,10 @@ class LLM
433504
434505 int next_token = -1 ;
435506 t_cqdm cqdm = create_cqdm (_attr.max_token_len , 32 );
436- std::vector<unsigned short > embed (_attr.tokens_embed_size , 0 );
437507
438- memcpy (embed.data (),
439- test_embed.data () + (input_embed_num - 1 ) * _attr.tokens_embed_size ,
440- _attr.tokens_embed_size * sizeof (unsigned short ));
508+ // memcpy(embed.data(),
509+ // test_embed.data() + (input_embed_num - 1) * _attr.tokens_embed_size,
510+ // _attr.tokens_embed_size * sizeof(unsigned short));
441511
442512 {
443513
0 commit comments