Skip to content

Commit 5ed3bcc

Browse files
Authored: Merge pull request #7 from ZHEQIUSHUI/prefill_token_512
Commit subject: "fix add 512 token prefill" (i.e. add support for 512-token prefill)
2 parents (d94638d + c41b3ea) — merge commit 5ed3bcc

File tree

2 files changed

+129
-56
lines changed

2 files changed

+129
-56
lines changed

src/runner/LLM.hpp

Lines changed: 126 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ struct LLMAttrType
2323

2424
// std::string template_prefill_filename_axmodel = "minicpmv/prefill_axmodel/minicpm_p96_l%d.axmodel";
2525
// int prefill_axmodel_num = 40;
26-
int prefill_token_num = 96; // auto calc
26+
int prefill_token_num = 128; // auto calc
27+
int prefill_max_token_num = 512;
2728

2829
std::string filename_post_axmodel = "tinyllama-int8/tinyllama_post.axmodel";
2930

@@ -76,7 +77,7 @@ class LLM
7677
std::vector<LLMLayer> llama_layers;
7778
ax_runner_ax650 llama_post;
7879

79-
int prefill_grpid = 1;
80+
// int prefill_grpid = 1;
8081
int decode_grpid = 0;
8182

8283
// ax_runner_ax650 vpm_resampler;
@@ -235,7 +236,7 @@ class LLM
235236
ALOGE("init axmodel(%s) failed", layer.filename.c_str());
236237
}
237238
}
238-
239+
printf("\n");
239240
{
240241
_attr.max_token_len = llama_layers[0].layer.get_input("mask").nSize / sizeof(unsigned short) - 1;
241242
ALOGI("max_token_len : %d", _attr.max_token_len);
@@ -250,8 +251,10 @@ class LLM
250251
return false;
251252
}
252253

253-
_attr.prefill_token_num = llama_layers[0].layer.get_input(prefill_grpid, "indices").vShape[1];
254+
_attr.prefill_token_num = llama_layers[0].layer.get_input(1, "indices").vShape[1];
254255
ALOGI("prefill_token_num : %d", _attr.prefill_token_num);
256+
_attr.prefill_max_token_num = llama_layers[0].layer.get_input(llama_layers[0].layer.get_num_input_groups() - 1, "mask").vShape[2];
257+
ALOGI("prefill_max_token_num : %d", _attr.prefill_max_token_num);
255258
}
256259
if (attr.b_dynamic_load_axmodel_layer)
257260
{
@@ -298,9 +301,9 @@ class LLM
298301
int Encode(std::vector<unsigned short> &out_embed, std::string prompt = "What is in the image?")
299302
{
300303
std::vector<int> input_ids = tokenizer->Encode(prompt, true);
301-
if (input_ids.size() > _attr.prefill_token_num)
304+
if (input_ids.size() > _attr.prefill_max_token_num)
302305
{
303-
ALOGE("input_ids(%d) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num);
306+
ALOGE("input_ids(%d) > prefill_max_token_num(%d)", input_ids.size(), _attr.prefill_max_token_num);
304307
return -1;
305308
}
306309
out_embed.resize(input_ids.size() * _attr.tokens_embed_size);
@@ -327,23 +330,50 @@ class LLM
327330
b_stop = false;
328331
std::string final_out;
329332

333+
int input_embed_num = test_embed.size() / _attr.tokens_embed_size;
334+
ALOGI("input token num : %d", input_embed_num);
335+
int prefill_split_num = ceil((double)input_embed_num / _attr.prefill_token_num);
336+
// ALOGI("prefill_split_num : %d", prefill_split_num);
337+
330338
bfloat16 bf16 = -65536.f;
331339
std::vector<unsigned short> mask(_attr.kv_cache_num + 1, bf16.data);
332-
std::vector<unsigned short> mask_p(_attr.prefill_token_num * _attr.prefill_token_num, bf16.data);
340+
std::vector<std::vector<unsigned short>> mask_p(prefill_split_num);
341+
std::vector<unsigned short> embed(_attr.tokens_embed_size, 0);
333342

334-
for (size_t i = 0; i < _attr.prefill_token_num; i++)
343+
// for (size_t i = 0; i < _attr.prefill_token_num; i++)
344+
// {
345+
// for (size_t j = 0; j < i + 1; j++)
346+
// {
347+
// mask_p[i * _attr.prefill_token_num + j] = 0;
348+
// }
349+
// }
350+
for (size_t p = 0; p < prefill_split_num; p++)
335351
{
336-
for (size_t j = 0; j < i + 1; j++)
352+
std::vector<unsigned short> &mask_tmp = mask_p[p];
353+
mask_tmp.resize((p + 1) * _attr.prefill_token_num * _attr.prefill_token_num, bf16.data);
354+
355+
size_t i = 0;
356+
for (size_t t = p * _attr.prefill_token_num; t < (p + 1) * _attr.prefill_token_num; t++)
337357
{
338-
mask_p[i * _attr.prefill_token_num + j] = 0;
358+
if (t < input_embed_num)
359+
{
360+
for (size_t j = 0; j < p * _attr.prefill_token_num + i + 1; j++)
361+
mask_tmp[i * ((p + 1) * _attr.prefill_token_num) + j] = 0;
362+
}
363+
i++;
339364
}
365+
// char path[128];
366+
// sprintf(path, "mask_p_%d.bin", p);
367+
// FILE *fp = fopen(path, "wb");
368+
// fwrite(mask_tmp.data(), sizeof(unsigned short), mask_tmp.size(), fp);
369+
// fclose(fp);
340370
}
341371

342372
std::vector<int> cached_token;
343373
std::vector<int> token_ids;
344374
// std::vector<int> token_ids = tokenizer->Encode(input_str);
345375
// int len_of_input = token_ids.size();
346-
int input_embed_num = test_embed.size() / _attr.tokens_embed_size;
376+
347377
// ALOGI("input_embed_num(%d)", input_embed_num);
348378

349379
mask[_attr.kv_cache_num] = 0;
@@ -355,70 +385,111 @@ class LLM
355385
timer ttft_timer;
356386
ttft_timer.start();
357387

358-
for (unsigned int m = 0; m < _attr.axmodel_num; m++)
388+
for (size_t p = 0; p < prefill_split_num; p++)
359389
{
360390
if (b_stop)
361391
{
362392
break;
363393
}
364394

365-
auto &layer = llama_layers[m];
366-
auto &layer_llama = llama_layers[m];
395+
std::vector<unsigned short> &mask_tmp = mask_p[p];
396+
std::vector<unsigned short> embed_tmp(_attr.prefill_token_num * _attr.tokens_embed_size, 0);
397+
if (p == (prefill_split_num - 1))
398+
{
399+
memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, (input_embed_num - p * _attr.prefill_token_num) * _attr.tokens_embed_size * sizeof(unsigned short));
400+
}
401+
else
402+
{
403+
memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, _attr.prefill_token_num * _attr.tokens_embed_size * sizeof(unsigned short));
404+
}
405+
int prefill_grpid = p + 1;
367406

368-
if (_attr.b_dynamic_load_axmodel_layer)
407+
for (unsigned int m = 0; m < _attr.axmodel_num; m++)
369408
{
370-
int ret;
371-
if (_attr.b_use_mmap_load_layer)
409+
if (b_stop)
372410
{
373-
ret = layer.layer.init((char *)layer.layer_buffer.data(), layer.layer_buffer.size());
411+
break;
374412
}
375-
else
413+
414+
auto &layer = llama_layers[m];
415+
auto &layer_llama = llama_layers[m];
416+
417+
if (_attr.b_dynamic_load_axmodel_layer)
376418
{
377-
ret = layer.layer.init(layer.layer_buffer_vec.data(), layer.layer_buffer_vec.size());
419+
int ret;
420+
if (_attr.b_use_mmap_load_layer)
421+
{
422+
ret = layer.layer.init((char *)layer.layer_buffer.data(), layer.layer_buffer.size());
423+
}
424+
else
425+
{
426+
ret = layer.layer.init(layer.layer_buffer_vec.data(), layer.layer_buffer_vec.size());
427+
}
428+
if (ret != 0)
429+
{
430+
ALOGE("init axmodel(%s) failed", layer.filename.c_str());
431+
}
378432
}
379-
if (ret != 0)
433+
434+
auto &input_indices = layer.layer.get_input(prefill_grpid, "indices");
435+
unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr;
436+
437+
for (unsigned int i = 0; i < _attr.prefill_token_num; i++)
380438
{
381-
ALOGE("init axmodel(%s) failed", layer.filename.c_str());
439+
input_indices_ptr[i] = p * _attr.prefill_token_num + i;
382440
}
383-
}
384441

385-
auto &input_indices = layer.layer.get_input(prefill_grpid, "indices");
386-
unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr;
387-
for (unsigned int i = 0; i < input_embed_num; i++)
388-
{
389-
input_indices_ptr[i] = i;
390-
}
442+
if (p > 0)
443+
{
444+
auto &input_prefill_k_cache = layer.layer.get_input(prefill_grpid, "K_cache");
445+
auto &input_prefill_v_cache = layer.layer.get_input(prefill_grpid, "V_cache");
446+
for (size_t i = 0; i < p; i++)
447+
{
448+
auto &output_k_cache = layer.layer.get_output(i + 1, "K_cache_out");
449+
memcpy((unsigned short *)input_prefill_k_cache.pVirAddr + i * _attr.prefill_token_num * _attr.kv_cache_size,
450+
output_k_cache.pVirAddr,
451+
sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
452+
453+
auto &output_v_cache = layer.layer.get_output(i + 1, "V_cache_out");
454+
memcpy((unsigned short *)input_prefill_v_cache.pVirAddr + i * _attr.prefill_token_num * _attr.kv_cache_size,
455+
output_v_cache.pVirAddr,
456+
sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
457+
}
458+
}
391459

392-
auto &input_mask = layer.layer.get_input(prefill_grpid, "mask");
393-
memcpy(input_mask.pVirAddr, mask_p.data(), mask_p.size() * sizeof(unsigned short));
460+
auto &input_mask = layer.layer.get_input(prefill_grpid, "mask");
461+
memcpy(input_mask.pVirAddr, mask_tmp.data(), mask_tmp.size() * sizeof(unsigned short));
394462

395-
auto &input_input = layer.layer.get_input(prefill_grpid, "input");
396-
memcpy(input_input.pVirAddr, test_embed.data(), test_embed.size() * sizeof(unsigned short));
397-
if (m == 0)
398-
{
399-
test_embed.resize(_attr.prefill_token_num * _attr.tokens_embed_size);
400-
}
463+
auto &input_input = layer.layer.get_input(prefill_grpid, "input");
464+
memcpy(input_input.pVirAddr, embed_tmp.data(), embed_tmp.size() * sizeof(unsigned short));
401465

402-
layer.layer.inference(prefill_grpid);
466+
layer.layer.inference(prefill_grpid);
403467

404-
auto &output_k_cache = layer.layer.get_output(prefill_grpid, "K_cache_out");
405-
AX_SYS_MinvalidateCache(output_k_cache.phyAddr, output_k_cache.pVirAddr, output_k_cache.nSize);
406-
auto &input_k_cache = layer_llama.layer.get_input(decode_grpid, "K_cache");
407-
memcpy(input_k_cache.pVirAddr, output_k_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
468+
auto &output_k_cache = layer.layer.get_output(prefill_grpid, "K_cache_out");
469+
AX_SYS_MinvalidateCache(output_k_cache.phyAddr, output_k_cache.pVirAddr, output_k_cache.nSize);
470+
auto &input_k_cache = layer_llama.layer.get_input(decode_grpid, "K_cache");
471+
memcpy((unsigned short *)input_k_cache.pVirAddr + p * _attr.prefill_token_num * _attr.kv_cache_size, output_k_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
408472

409-
auto &output_v_cache = layer.layer.get_output(prefill_grpid, "V_cache_out");
410-
AX_SYS_MinvalidateCache(output_v_cache.phyAddr, output_v_cache.pVirAddr, output_v_cache.nSize);
411-
auto &input_v_cache = layer_llama.layer.get_input(decode_grpid, "V_cache");
412-
memcpy(input_v_cache.pVirAddr, output_v_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
473+
auto &output_v_cache = layer.layer.get_output(prefill_grpid, "V_cache_out");
474+
AX_SYS_MinvalidateCache(output_v_cache.phyAddr, output_v_cache.pVirAddr, output_v_cache.nSize);
475+
auto &input_v_cache = layer_llama.layer.get_input(decode_grpid, "V_cache");
476+
memcpy((unsigned short *)input_v_cache.pVirAddr + p * _attr.prefill_token_num * _attr.kv_cache_size, output_v_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
413477

414-
auto &output = layer.layer.get_output(prefill_grpid, "output");
415-
AX_SYS_MinvalidateCache(output.phyAddr, output.pVirAddr, output.nSize);
416-
memcpy(test_embed.data(), output.pVirAddr, test_embed.size() * sizeof(unsigned short));
417-
if (_attr.b_dynamic_load_axmodel_layer)
478+
auto &output = layer.layer.get_output(prefill_grpid, "output");
479+
AX_SYS_MinvalidateCache(output.phyAddr, output.pVirAddr, output.nSize);
480+
memcpy(embed_tmp.data(), output.pVirAddr, embed_tmp.size() * sizeof(unsigned short));
481+
if (_attr.b_dynamic_load_axmodel_layer)
482+
{
483+
layer.layer.deinit();
484+
}
485+
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
486+
}
487+
if (p == (prefill_split_num - 1))
418488
{
419-
layer.layer.deinit();
489+
memcpy(embed.data(),
490+
embed_tmp.data() + (input_embed_num - p * _attr.prefill_token_num - 1) * _attr.tokens_embed_size,
491+
_attr.tokens_embed_size * sizeof(unsigned short));
420492
}
421-
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
422493
}
423494

424495
// ALOGI("prefill time cost: %.2f s", t_cost.cost() / 1000);
@@ -433,11 +504,10 @@ class LLM
433504

434505
int next_token = -1;
435506
t_cqdm cqdm = create_cqdm(_attr.max_token_len, 32);
436-
std::vector<unsigned short> embed(_attr.tokens_embed_size, 0);
437507

438-
memcpy(embed.data(),
439-
test_embed.data() + (input_embed_num - 1) * _attr.tokens_embed_size,
440-
_attr.tokens_embed_size * sizeof(unsigned short));
508+
// memcpy(embed.data(),
509+
// test_embed.data() + (input_embed_num - 1) * _attr.tokens_embed_size,
510+
// _attr.tokens_embed_size * sizeof(unsigned short));
441511

442512
{
443513

src/runner/ax_model_runner/ax_model_runner.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class ax_runner_base
6161
int get_num_inputs() { return minput_tensors.size(); };
6262
int get_num_outputs() { return moutput_tensors.size(); };
6363

64+
int get_num_input_groups() { return mgroup_input_tensors.size(); };
65+
int get_num_output_groups() { return mgroup_output_tensors.size(); };
66+
6467
const ax_runner_tensor_t &get_input(int idx) { return minput_tensors[idx]; }
6568
const ax_runner_tensor_t *get_inputs_ptr() { return minput_tensors.data(); }
6669
const ax_runner_tensor_t &get_input(std::string name)

0 commit comments

Comments (0)