Skip to content

Commit 274d55e

Browse files
Kush Rastogi and facebook-github-bot
authored and committed
Adding KV to Prefill IO (pytorch#9466)
Summary: Pull Request resolved: pytorch#9466 Differential Revision: D71567692
1 parent 5c5b84e commit 274d55e

File tree

5 files changed

+112
-36
lines changed

5 files changed

+112
-36
lines changed

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ int main(int argc, char** argv) {
7676
std::vector<char> buf;
7777
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
7878
std::ofstream fout(FLAGS_output_path.c_str());
79-
auto callback = [&](const std::string& piece) {
79+
80+
int32_t num_total_tokens = 0;
81+
82+
auto callback = [&](const std::string& piece, int32_t tokens_generated) {
83+
num_total_tokens += tokens_generated;
8084
for (const char c : piece) {
8185
buf.push_back(c);
8286
}
@@ -85,6 +89,7 @@ int main(int argc, char** argv) {
8589
for (int i = 0; i < FLAGS_num_iters; i++) {
8690
runner.generate(
8791
FLAGS_seq_len,
92+
num_total_tokens,
8893
FLAGS_prompt.c_str(),
8994
FLAGS_system_prompt.c_str(),
9095
callback);

examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,52 @@ void ShiftPointerIoMgr::prepare_prefill_io(
494494
}
495495
}
496496

497+
void ShiftPointerIoMgr::update_kv_to_prefill_io(
498+
int64_t pos,
499+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
500+
// update v_cache
501+
assert(pos <= 512);
502+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
503+
v_cache_in_[prefill_forward_name_];
504+
for (int i = 0, v_cache_stride = head_dim_ * pos; i < v_cache_in.size();
505+
i++) {
506+
v_cache_in[i]->set_data(
507+
v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
508+
}
509+
510+
// update k_cache
511+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
512+
k_cache_in_[prefill_forward_name_];
513+
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_out =
514+
k_cache_out_[prefill_forward_name_];
515+
516+
size_t copied_size = pos * sizeof(uint8_t);
517+
518+
for (int i = 0, k_cache_stride = pos * sizeof(uint8_t); i < k_cache_in_.size();
519+
i++) {
520+
k_cache_in[i]->set_data(
521+
k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
522+
// assume current pointer has been shifted to absolute position of `pos`
523+
uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos;
524+
for (int j = head_dim_; j > -1; --j) {
525+
memcpy(
526+
ptr_in + j * prefill_cache_len_,
527+
ptr_in + j * kv_cache_len_,
528+
copied_size);
529+
}
530+
k_cache_out[i]->set_data(
531+
k_cache_out[i]->mutable_data<uint8_t>() + k_cache_stride);
532+
}
533+
534+
// Setting attention mask from context_len - prefill_ar_len - i to context_len
535+
IO* ptr = static_cast<IO*>(data_ptr_.get());
536+
for (int i = prefill_ar_len_; i < pos; i++) {
537+
for (int j = 0; j < prefill_ar_len_; j++) {
538+
ptr->prefill_attention_mask[j * context_len_ + context_len_ - prefill_ar_len_ - i] = 65535;
539+
}
540+
}
541+
}
542+
497543
void ShiftPointerIoMgr::update_prefill_to_kv_io(
498544
int64_t cur_token,
499545
int64_t pos,
@@ -664,33 +710,32 @@ void ShiftPointerIoMgr::update_prefill_io(
664710
}
665711

666712
void ShiftPointerIoMgr::fill_prefill_toks(
667-
int64_t start_pos,
713+
int64_t num_prev_tokens,
714+
int64_t prompt_pos,
668715
std::vector<uint64_t>& prompt_tokens) {
669716
IO* ptr = static_cast<IO*>(get_mutable_ptr());
670717
for (int i = 0; i < prefill_ar_len_; i++) {
671718
if (!is_bert_) {
672-
ptr->prefill_input_pos[i] = start_pos + i;
719+
ptr->prefill_input_pos[i] = num_prev_tokens + prompt_pos + i;
673720
}
674721

675-
if (start_pos + i < prompt_tokens.size()) {
722+
if (prompt_pos + i < prompt_tokens.size()) {
676723
// Support CPU 4-bit embedding, which requires int64 input.
677724
// However, for QNN embedding, only int32 input is needed.
678725
// Therefore, we need to cast to the correct type to write the data.
679726
if (use_int64_token_) {
680-
ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i];
727+
ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i];
681728
} else {
682729
int32_t* prefill_input_toks_ptr =
683730
reinterpret_cast<int32_t*>(ptr->prefill_input_toks.data());
684731
prefill_input_toks_ptr[i] =
685-
static_cast<int32_t>(prompt_tokens[start_pos + i]);
732+
static_cast<int32_t>(prompt_tokens[prompt_pos + i]);
686733
}
687734
}
688-
if (start_pos >= prefill_ar_len_) {
689-
for (int j = 0,
690-
offset = i * context_len_ +
691-
(context_len_ - prefill_ar_len_ - start_pos);
692-
j < prefill_ar_len_;
693-
++j) {
735+
if (num_prev_tokens + prompt_pos >= prefill_ar_len_) {
736+
int64_t start_offset = i * context_len_ +
737+
(context_len_ - num_prev_tokens - prompt_pos - prefill_ar_len_);
738+
for (int j = 0, offset = start_offset; j < prefill_ar_len_; ++j) {
694739
ptr->prefill_attention_mask[offset + j] = 65535;
695740
}
696741
}
@@ -1305,6 +1350,12 @@ void SmartMaskIoMgr::prepare_prefill_io(
13051350
}
13061351
}
13071352

1353+
void SmartMaskIoMgr::update_kv_to_prefill_io(
    int64_t pos,
    std::vector<std::vector<Tensor>>& output_tensors) {
  // Intentionally a no-op stub: the SmartMask path does not yet support
  // switching the prefill method's IO back from the KV-cache layout.
  // TODO: Fill In
}
1358+
13081359
void SmartMaskIoMgr::update_prefill_to_kv_io(
13091360
int64_t cur_token,
13101361
int64_t pos,
@@ -1396,29 +1447,30 @@ void SmartMaskIoMgr::update_prefill_io(
13961447
}
13971448

13981449
void SmartMaskIoMgr::fill_prefill_toks(
1399-
int64_t start_pos,
1450+
int64_t num_prev_tokens,
1451+
int64_t prompt_pos,
14001452
std::vector<uint64_t>& prompt_tokens) {
14011453
IO* ptr = static_cast<IO*>(get_mutable_ptr());
14021454
for (int i = 0; i < prefill_ar_len_; i++) {
14031455
if (!is_bert_) {
1404-
ptr->prefill_input_pos[i] = start_pos + i;
1456+
ptr->prefill_input_pos[i] = prompt_pos + i;
14051457
}
14061458

1407-
if (start_pos + i < prompt_tokens.size()) {
1459+
if (prompt_pos + i < prompt_tokens.size()) {
14081460
// Support CPU 4-bit embedding, which requires int64 input.
14091461
// However, for QNN embedding, only int32 input is needed.
14101462
// Therefore, we need to cast to the correct type to write the data.
14111463
if (use_int64_token_) {
1412-
ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i];
1464+
ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i];
14131465
} else {
14141466
int32_t* prefill_input_toks_ptr =
14151467
reinterpret_cast<int32_t*>(ptr->prefill_input_toks);
14161468
prefill_input_toks_ptr[i] =
1417-
static_cast<int32_t>(prompt_tokens[start_pos + i]);
1469+
static_cast<int32_t>(prompt_tokens[prompt_pos + i]);
14181470
}
14191471
}
1420-
if (start_pos >= prefill_ar_len_) {
1421-
for (int j = 0, offset = i * context_len_ + (start_pos - prefill_ar_len_);
1472+
if (prompt_pos >= prefill_ar_len_) {
1473+
for (int j = 0, offset = i * context_len_ + (prompt_pos - prefill_ar_len_);
14221474
j < prefill_ar_len_;
14231475
++j) {
14241476
ptr->prefill_attention_mask[offset + j] = 65535;

examples/qualcomm/oss_scripts/llama/runner/io_manager.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,13 @@ class IoMgrBase {
4848
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
4949
methods_meta) = 0;
5050
virtual void fill_prefill_toks(
51-
int64_t start_pos,
51+
int64_t num_prev_tokens,
52+
int64_t prompt_pos,
5253
std::vector<uint64_t>& prompt_tokens) = 0;
5354
virtual void fill_kv_tok_mask(int64_t pos, int64_t cur_token) = 0;
55+
virtual void update_kv_to_prefill_io(
56+
int64_t pos,
57+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
5458
virtual void update_prefill_to_kv_io(
5559
int64_t cur_token,
5660
int64_t pos,
@@ -118,9 +122,13 @@ class ShiftPointerIoMgr : public IoMgrBase {
118122
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
119123
methods_meta) override;
120124
void fill_prefill_toks(
121-
int64_t start_pos,
125+
int64_t num_prev_tokens,
126+
int64_t prompt_pos,
122127
std::vector<uint64_t>& prompt_tokens) override;
123128
void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
129+
void update_kv_to_prefill_io(
130+
int64_t pos,
131+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
124132
void update_prefill_to_kv_io(
125133
int64_t cur_token,
126134
int64_t pos,
@@ -226,9 +234,13 @@ class SmartMaskIoMgr : public IoMgrBase {
226234
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
227235
methods_meta) override;
228236
void fill_prefill_toks(
229-
int64_t start_pos,
237+
int64_t num_prev_tokens,
238+
int64_t prompt_pos,
230239
std::vector<uint64_t>& prompt_tokens) override;
231240
void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
241+
void update_kv_to_prefill_io(
242+
int64_t pos,
243+
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
232244
void update_prefill_to_kv_io(
233245
int64_t cur_token,
234246
int64_t pos,

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,10 @@ void Runner::run_model_step(
276276

277277
Error Runner::generate(
278278
int32_t seq_len,
279+
int32_t num_prev_tokens,
279280
const std::string& prompt,
280281
const std::string& system_prompt,
281-
std::function<void(const std::string&)> token_callback,
282+
std::function<void(const std::string&, int32_t)> token_callback,
282283
std::function<void(const Stats&)> stats_callback) {
283284
std::unordered_map<std::string, std::vector<std::vector<Tensor>>>
284285
input_tensors, output_tensors;
@@ -327,14 +328,16 @@ Error Runner::generate(
327328
prompt_.append(
328329
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
329330
if (token_callback) {
330-
token_callback("<|begin_of_text|>");
331+
token_callback("<|begin_of_text|>", 0);
331332
}
332333
break;
333334
default:
334335
ET_CHECK_MSG(false, "unsupported llama version");
335336
break;
336337
}
337338

339+
ET_LOG(Info, "Number of Previous Tokens Prefill + Decode: %d", num_prev_tokens);
340+
338341
seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
339342
tokenizers::Result<std::vector<uint64_t>> encode_res =
340343
tokenizer_->encode(prompt_, n_bos_, 0);
@@ -349,7 +352,7 @@ Error Runner::generate(
349352

350353
int64_t pos = 0, prev_token, cur_token = prompt_tokens[0];
351354
if (token_callback) {
352-
token_callback(prompt_);
355+
token_callback(prompt_, num_prompt_tokens);
353356
}
354357
auto prefill_execute = [&](const std::string& method_name) {
355358
int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_);
@@ -361,7 +364,7 @@ Error Runner::generate(
361364
num_iters);
362365

363366
for (int i = 0; i < num_iters; i++) {
364-
io_mgr_->fill_prefill_toks(pos, prompt_tokens);
367+
io_mgr_->fill_prefill_toks(num_prev_tokens, pos, prompt_tokens);
365368
run_model_step(method_name, inputs[method_name]);
366369
io_mgr_->update_prefill_io(cur_token, pos, output_tensors[method_name]);
367370
pos += prefill_ar_len_;
@@ -377,10 +380,12 @@ Error Runner::generate(
377380
auto piece_res = tokenizer_->decode(prev_token, cur_token);
378381
ET_CHECK(piece_res.ok());
379382
if (token_callback) {
380-
token_callback(piece_res.get().c_str());
383+
ET_LOG(Info, "Prefill: %s", piece_res.get().c_str());
384+
token_callback(piece_res.get().c_str(), 1);
381385
}
382386

383-
pos = num_prompt_tokens;
387+
pos = num_prev_tokens + num_prompt_tokens;
388+
ET_LOG(Info, "Pos: %ld, Prompt Tokens: %ld", pos, num_prompt_tokens);
384389
stats_.first_token_ms = time_in_ms();
385390
stats_.prompt_eval_end_ms = time_in_ms();
386391
};
@@ -394,9 +399,9 @@ Error Runner::generate(
394399

395400
// hybrid mode will check these stats_ at prefill(prefill)
396401
if (eval_mode_ == EvalMode::kKVCached) {
397-
if (pos == num_prompt_tokens) {
402+
if (pos == num_prev_tokens + num_prompt_tokens) {
398403
stats_.first_token_ms = time_in_ms();
399-
} else if (pos == num_prompt_tokens - 1) {
404+
} else if (pos == num_prev_tokens + num_prompt_tokens - 1) {
400405
stats_.prompt_eval_end_ms = time_in_ms();
401406
}
402407
}
@@ -405,15 +410,15 @@ Error Runner::generate(
405410
cur_token = logitsToToken(logits_tensor, pos);
406411
stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;
407412

408-
if (pos < num_prompt_tokens - 1) {
413+
if (pos < num_prev_tokens + num_prompt_tokens - 1) {
409414
cur_token = prompt_tokens[pos + 1];
410415
}
411416
io_mgr_->update_kv_io(cur_token, ++pos, output_tensors[method_name]);
412417
auto piece_res = tokenizer_->decode(prev_token, cur_token);
413418
ET_CHECK(piece_res.ok());
414419

415420
if (token_callback && pos >= num_prompt_tokens) {
416-
token_callback(piece_res.get().c_str());
421+
token_callback(piece_res.get().c_str(), 1);
417422
}
418423

419424
if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
@@ -432,6 +437,7 @@ Error Runner::generate(
432437
io_mgr_->update_prefill_to_kv_io(
433438
cur_token, pos, output_tensors[kv_forward_name_]);
434439
kv_execute(kv_forward_name_);
440+
io_mgr_->update_kv_to_prefill_io(pos, output_tensors[prefill_forward_name_]);
435441
break;
436442
default:
437443
ET_CHECK_MSG(false, "Unsupported eval mode");
@@ -448,9 +454,9 @@ Error Runner::generate(
448454
if (stats_callback) {
449455
stats_callback(stats_);
450456
}
451-
io_mgr_->reset_io(
452-
get_methods_meta(prefill_forward_name_),
453-
get_methods_meta(kv_forward_name_));
457+
// io_mgr_->reset_io(
458+
// get_methods_meta(prefill_forward_name_),
459+
// get_methods_meta(kv_forward_name_));
454460
prompt_.clear();
455461
return Error::Ok;
456462
}

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ class Runner {
6767
executorch::runtime::Error load();
6868
executorch::runtime::Error generate(
6969
int32_t seq_len,
70+
int32_t num_prev_tokens,
7071
const std::string& prompt,
7172
const std::string& system_prompt,
72-
std::function<void(const std::string&)> token_callback = {},
73+
std::function<void(const std::string&, int32_t)> token_callback = {},
7374
std::function<void(const Stats&)> stats_callback = {});
7475
void stop();
7576
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>

0 commit comments

Comments
 (0)