 * LICENSE file in the root directory of this source tree.
 */

-#include <algorithm>
-#include <fstream>
-
#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <algorithm>

using executorch::aten::Tensor;
using executorch::aten::TensorImpl;
@@ -55,7 +53,8 @@ std::vector<Tensor> Memory::get_output_tensors(

HybridMemory::HybridMemory(
    std::vector<std::shared_ptr<Module>>& modules,
-    int32_t max_seq_len,
+    int32_t prefill_cache_len,
+    int32_t kv_cache_len,
    int32_t vocab_size,
    int32_t num_layers,
    int32_t head_dim,
@@ -65,7 +64,8 @@ HybridMemory::HybridMemory(
    const std::string& kv_forward_name)
    : Memory(modules),
      shard_layers_({num_layers}),
-      max_seq_len_(max_seq_len),
+      prefill_cache_len_(prefill_cache_len),
+      kv_cache_len_(kv_cache_len),
      vocab_size_(vocab_size),
      num_layers_(num_layers),
      head_dim_(head_dim),
@@ -106,17 +106,17 @@ HybridMemory::HybridMemory(
      new IO, [](void* ptr) { delete static_cast<IO*>(ptr); });
}

-void HybridMemory::init_io(
-    const std::vector<executorch::runtime::Result<
-        executorch::runtime::MethodMeta>>& methods_meta,
-    EvalMode eval_mode) {
+void HybridMemory::init_io() {
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  std::memset(ptr, 0, sizeof(IO));

-  int32_t cache_len = max_seq_len_ - 1;
-  int32_t k_in_size = (head_dim_ + 1) * (max_seq_len_ - 1);
-  int32_t k_cache_out_size = num_heads_ * head_dim_ * cache_len;
-  int32_t v_cache_size = (num_heads_ + 1) * (max_seq_len_ - 1) * head_dim_;
+  int32_t max_cache_len = std::max(kv_cache_len_, prefill_cache_len_);
+  int32_t k_in_size = (head_dim_ + 1) * max_cache_len;
+  int32_t v_cache_size = (num_heads_ + 1) * max_cache_len * head_dim_;
+  int32_t k_cache_out_size = num_heads_ * head_dim_;
+  if (eval_mode_ == EvalMode::kHybrid || eval_mode_ == EvalMode::kPrefill) {
+    k_cache_out_size *= prefill_cache_len_;
+  }

  // Init kv vector shape, general enough to be shared across all 3 modes.
  ptr->k_cache_out.reserve(num_layers_);
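
To make the sizing change above concrete, here is a minimal standalone sketch (the struct and helper names are mine, not part of the runner): the k input and v cache are sized for the larger of the two cache lengths so one allocation can serve both graphs, while k_cache_out only needs a full prefill window of rows when a prefill graph exists.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Illustrative sketch only: mirrors the per-layer, per-head sizing logic above.
// The struct and function names are hypothetical, not part of io_memory.h.
struct CacheSizes {
  int32_t k_in;
  int32_t v_cache;
  int32_t k_cache_out;
};

CacheSizes compute_sizes(
    int32_t num_heads,
    int32_t head_dim,
    int32_t prefill_cache_len,
    int32_t kv_cache_len,
    bool has_prefill_graph) {
  const int32_t max_cache_len = std::max(kv_cache_len, prefill_cache_len);
  CacheSizes s;
  s.k_in = (head_dim + 1) * max_cache_len;
  s.v_cache = (num_heads + 1) * max_cache_len * head_dim;
  s.k_cache_out = num_heads * head_dim;
  if (has_prefill_graph) {
    // Prefill emits k for every prompt position, not just the current token.
    s.k_cache_out *= prefill_cache_len;
  }
  return s;
}

int main() {
  // Example numbers only (e.g. 32 heads, head_dim 128, prefill 32, kv 512).
  const CacheSizes s = compute_sizes(32, 128, 32, 512, /*has_prefill_graph=*/true);
  std::cout << s.k_in << " " << s.v_cache << " " << s.k_cache_out << "\n";
  return 0;
}
```
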
@@ -127,14 +127,14 @@ void HybridMemory::init_io(
  }

  auto init_prefill = [&]() {
-    ptr->prefill_input_toks.resize(cache_len);
-    ptr->prefill_atten_mask.resize(cache_len * cache_len);
-    ptr->prefill_logits.resize(cache_len * vocab_size_);
+    ptr->prefill_input_toks.resize(prefill_cache_len_);
+    ptr->prefill_atten_mask.resize(prefill_cache_len_ * prefill_cache_len_);
+    ptr->prefill_logits.resize(prefill_cache_len_ * vocab_size_);
  };

  auto init_kv = [&]() {
    ptr->kv_logits.resize(vocab_size_);
-    ptr->kv_attention_mask.resize(max_seq_len_, -255);
+    ptr->kv_attention_mask.resize((kv_cache_len_ + 1), -255);
    ptr->k_cache.reserve(num_layers_);
    for (int layer = 0; layer < num_layers_; layer++) {
      ptr->k_cache.emplace_back();
@@ -145,7 +145,7 @@ void HybridMemory::init_io(
    }
  };

-  switch (eval_mode) {
+  switch (eval_mode_) {
    case EvalMode::kPrefill:
      init_prefill();
      break;
@@ -205,9 +205,7 @@ void HybridMemory::prepare_kv_io(

  // [I] kv_cache
  int index = 3; // bypass input_tokens, input_pos, atten_mask
-  for (int offset = 0,
-           shard_index = 0,
-           v_stride = (max_seq_len_ - 1) * head_dim_;
+  for (int offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_;
       shard_index < modules_.size();
       offset += shard_layers_[shard_index], shard_index++) {
    for (int cache_group = 0; cache_group < 2; ++cache_group) {
@@ -256,9 +254,7 @@ void HybridMemory::prepare_kv_io(
  // For k, we store it in k_cache_out and update to k_cache later.
  // For v, we append the output to the end of v_cache,
  // which serves as both input and output.
-  for (int offset = 0,
-           shard_index = 0,
-           v_stride = (max_seq_len_ - 1) * head_dim_;
+  for (int offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_;
       shard_index < modules_.size();
       offset += shard_layers_[shard_index], shard_index++) {
    for (int cache_group = 0; cache_group < 2; ++cache_group) {
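
The comment above describes v_cache as both input and output. As I read the offsets in this file (head * v_stride for the window a head reads, (head + 1) * v_stride for where its new values go), each per-layer buffer holds num_heads + 1 contiguous windows, so a head's output lands immediately after its input window. A toy raw-buffer sketch under that assumption, not the runner's actual tensor wiring:

```cpp
#include <cstdint>
#include <vector>

// Toy sketch of the flat per-layer v_cache layout, under my reading of the
// offsets above: (num_heads + 1) windows of v_stride elements, where head h
// reads window h and writes into window h + 1. The real runner exposes these
// regions through TensorImpl objects rather than raw pointers.
int main() {
  const int32_t num_heads = 2;
  const int32_t head_dim = 4;
  const int32_t kv_cache_len = 8;
  const int32_t v_stride = kv_cache_len * head_dim;

  std::vector<uint8_t> v_cache((num_heads + 1) * v_stride, 0);

  for (int32_t head = 0; head < num_heads; ++head) {
    uint8_t* in_window = v_cache.data() + head * v_stride;        // input view
    uint8_t* out_window = v_cache.data() + (head + 1) * v_stride; // output view
    // Writes through out_window sit right after in_window, so appending new
    // values needs no extra copy when the input view is later advanced.
    (void)in_window;
    (void)out_window;
  }
  return 0;
}
```
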
@@ -305,8 +301,6 @@ void HybridMemory::prepare_prefill_io(

  IO* ptr = static_cast<IO*>(data_ptr_.get());

-  // cache_len should be max_seq_len - 1
-  int32_t cache_len = methods_meta[0]->input_tensor_meta(0)->sizes()[1];
  // [I]: pre_input_tokens
  Result<TensorInfo> prefill_input_toks = methods_meta[0]->input_tensor_meta(0);
  prefill_input_toks_ = std::make_unique<TensorImpl>(
@@ -318,12 +312,12 @@ void HybridMemory::prepare_prefill_io(
          prefill_input_toks->dim_order().data()));
  input_tensors_[prefill_forward_name_][0].push_back(prefill_input_toks_.get());
  // [I]: prefill_attn_mask
-  for (int i = 0; i < cache_len; ++i) {
-    for (int j = 0; j < cache_len; ++j) {
+  for (int i = 0; i < prefill_cache_len_; ++i) {
+    for (int j = 0; j < prefill_cache_len_; ++j) {
      if (i < j) {
-        ptr->prefill_atten_mask[i * cache_len + j] = -255;
+        ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = -255;
      } else {
-        ptr->prefill_atten_mask[i * cache_len + j] = 0;
+        ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 0;
      }
    }
  }
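
The loop above builds a standard causal mask over the prefill window: row i keeps columns j <= i at 0 and pushes future columns to -255, a large negative bias before softmax. A tiny standalone version, with a small length and float storage chosen only for printing:

```cpp
#include <iostream>
#include <vector>

// Standalone illustration of the causal prefill mask built above. The element
// type and the small length here are illustrative; the runner's buffer lives
// inside its IO struct.
int main() {
  const int len = 4;
  std::vector<float> mask(len * len);
  for (int i = 0; i < len; ++i) {
    for (int j = 0; j < len; ++j) {
      // Future positions (j > i) get a large negative bias, past/current get 0.
      mask[i * len + j] = (i < j) ? -255.0f : 0.0f;
    }
  }
  for (int i = 0; i < len; ++i) {
    for (int j = 0; j < len; ++j) {
      std::cout << mask[i * len + j] << (j + 1 == len ? '\n' : ' ');
    }
  }
  return 0;
}
```
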
@@ -347,10 +341,22 @@ void HybridMemory::prepare_prefill_io(
      const_cast<TensorImpl::DimOrderType*>(logits->dim_order().data()));
  output_tensors_[prefill_forward_name_][modules_.size() - 1].push_back(
      prefill_logits_.get());
+
  // [O] kv_cache
  int index = 1;
-  for (int offset = 0, shard_index = 0, cache_stride = cache_len * head_dim_;
-       shard_index < modules_.size();
+  // prefill_k_stride should be equal to prefill_v_stride in prefill mode.
+  // In hybrid mode, we use the kv mode cache length for the v stride, since we
+  // want to update prefill's result onto the kv mode's input.
+  int32_t prefill_k_stride = prefill_cache_len_ * head_dim_;
+  int32_t prefill_v_stride =
+      std::max(prefill_cache_len_, kv_cache_len_) * head_dim_;
+
+  if (eval_mode_ == EvalMode::kPrefill) {
+    ET_CHECK_MSG(
+        prefill_k_stride == prefill_v_stride,
+        "prefill_k_stride should be equal to prefill_v_stride");
+  }
+  for (int offset = 0, shard_index = 0; shard_index < modules_.size();
       offset += shard_layers_[shard_index], shard_index++) {
    for (int cache_group = 0; cache_group < 2; ++cache_group) {
      for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) {
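
To make the stride comment concrete, here is a small worked example with numbers of my own choosing: in hybrid mode with a short prefill window and a longer kv cache, the v stride follows the kv cache length, so each head's prefill output is laid out on the same per-head grid the kv graph later reads; in pure prefill mode the two strides coincide, which the ET_CHECK_MSG above asserts.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Worked example of the stride choice above; all numbers are illustrative.
int main() {
  const int32_t head_dim = 128;
  const int32_t prefill_cache_len = 32; // prefill graph's sequence window
  const int32_t kv_cache_len = 512;     // kv (decode) graph's cache length

  const int32_t prefill_k_stride = prefill_cache_len * head_dim;
  const int32_t prefill_v_stride =
      std::max(prefill_cache_len, kv_cache_len) * head_dim;

  // Hybrid mode: k is packed densely (4096 per head) because it is copied out
  // later, while v is spaced on the kv grid (65536 per head) so prefill writes
  // land where the kv graph expects to read them.
  std::cout << prefill_k_stride << " " << prefill_v_stride << "\n";
  return 0;
}
```
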
@@ -363,10 +369,10 @@ void HybridMemory::prepare_prefill_io(
        void* cache_ptr = (cache_group == 0)
            ? static_cast<void*>(
                  ptr->k_cache_out[layer + offset].data() +
-                  head * cache_stride)
+                  head * prefill_k_stride)
            : static_cast<void*>(
                  ptr->v_cache[layer + offset].data() +
-                  (head + 1) * cache_stride);
+                  (head + 1) * prefill_v_stride);
        cache.emplace_back(std::make_unique<TensorImpl>(
            kv_cache->scalar_type(),
            kv_cache->sizes().size(),
@@ -386,15 +392,17 @@ void HybridMemory::update_prefill_to_kv_io(
    int64_t cur_token,
    int64_t pos,
    std::vector<std::vector<Tensor>>& output_tensors) {
-  int cache_len = (max_seq_len_ - 1);
+  ET_CHECK_MSG(kv_cache_len_ != 0, "kv_cache_len_ should not be 0");
+  ET_CHECK_MSG(
+      prefill_cache_len_ != 0, "prefill_cache_len_ should not be 0");
  IO* ptr = static_cast<IO*>(data_ptr_.get());

  ptr->input_tok = static_cast<int32_t>(cur_token);
  ptr->input_pos = static_cast<int32_t>(pos);
  // If the prompt length is 30, prefill handles tokens up to pos = 30.
  // At this point, pos should be 31.
  for (int i = 0; i < pos + 1; i++) {
-    ptr->kv_attention_mask[cache_len - i] = 0;
+    ptr->kv_attention_mask[kv_cache_len_ - i] = 0;
  }

  // update v_cache
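
The kv_attention_mask update in the hunk above is easier to see with a concrete buffer: the mask has kv_cache_len_ + 1 entries initialized to -255, and slots are opened from the right-hand end, index kv_cache_len_ down to kv_cache_len_ - pos. A small sketch with made-up lengths and float storage for readability:

```cpp
#include <iostream>
#include <vector>

// Sketch of the right-aligned kv attention mask update above; the lengths and
// element type here are illustrative.
int main() {
  const int kv_cache_len = 16;
  const int pos = 5; // e.g. after a 4-token prompt, pos is 5 per the comment above
  std::vector<float> kv_attention_mask(kv_cache_len + 1, -255.0f);

  // Open the slots for the current token and everything prefill already wrote,
  // counting backwards from the end of the mask.
  for (int i = 0; i < pos + 1; i++) {
    kv_attention_mask[kv_cache_len - i] = 0.0f;
  }

  for (float v : kv_attention_mask) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
  return 0;
}
```
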
@@ -429,9 +437,9 @@ void HybridMemory::update_prefill_to_kv_io(
  for (int i = 0; i < k_cache_in.size(); ++i) {
    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
    const uint8_t* ptr_out = k_cache_out[i]->data<uint8_t>();
-    for (size_t j = 0, offset = cache_len; j < head_dim_;
-         ++j, offset += cache_len) {
-      for (int k = 0, k_stride = j * cache_len; k < pos; k++) {
+    for (size_t j = 0, offset = kv_cache_len_; j < head_dim_;
+         ++j, offset += kv_cache_len_) {
+      for (int k = 0, k_stride = j * prefill_cache_len_; k < pos; k++) {
        ptr_in[offset + k] = ptr_out[k_stride + k];
      }
    }
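
The copy above re-strides prefill's k output into the kv graph's k input: the source packs each head's k as head_dim rows of prefill_cache_len_ values, while the kv-mode input expects rows kv_cache_len_ apart, so only the first pos columns of each row are copied across. A toy version with raw buffers and stand-in names:

```cpp
#include <cstdint>
#include <vector>

// Toy re-striding copy mirroring the loop above: the source packs head_dim
// rows of src_row_len (prefill) values, the destination uses rows of
// dst_row_len (kv) values and, like the buffer above, keeps one extra leading
// row. Only the first `pos` columns carry prefill results worth copying.
int main() {
  const int32_t head_dim = 4;
  const int32_t src_row_len = 8;  // stands in for prefill_cache_len_
  const int32_t dst_row_len = 16; // stands in for kv_cache_len_
  const int32_t pos = 5;          // positions filled by prefill

  std::vector<uint8_t> src(head_dim * src_row_len, 1);
  std::vector<uint8_t> dst((head_dim + 1) * dst_row_len, 0);

  const uint8_t* ptr_out = src.data();
  uint8_t* ptr_in = dst.data();
  for (int32_t j = 0, offset = dst_row_len; j < head_dim;
       ++j, offset += dst_row_len) {
    for (int32_t k = 0, k_stride = j * src_row_len; k < pos; k++) {
      ptr_in[offset + k] = ptr_out[k_stride + k];
    }
  }
  return 0;
}
```
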
@@ -444,13 +452,12 @@ void HybridMemory::update_kv_io(
    int64_t pos,
    std::vector<std::vector<Tensor>>& output_tensors) {
  IO* ptr = static_cast<IO*>(data_ptr_.get());
-  int seq_len = (max_seq_len_ - 1);
  // update input_tok
  ptr->input_tok = static_cast<int32_t>(cur_token);
  // update position_ids
  ptr->input_pos = static_cast<int32_t>(pos);
  // update causal mask for next token
-  ptr->kv_attention_mask[seq_len - pos] = 0;
+  ptr->kv_attention_mask[kv_cache_len_ - pos] = 0;

  // update v_cache
  auto& v_cache_in = v_cache_in_[kv_forward_name_];
@@ -480,8 +487,8 @@ void HybridMemory::update_kv_io(
  for (int i = 0; i < k_cache_in.size(); ++i) {
    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
    const uint8_t* ptr_out = k_cache_out[i]->data<uint8_t>();
-    for (size_t j = 0, offset = seq_len; j < head_dim_;
-         ++j, offset += seq_len) {
+    for (size_t j = 0, offset = kv_cache_len_; j < head_dim_;
+         ++j, offset += kv_cache_len_) {
      ptr_in[offset] = ptr_out[j];
    }
    k_cache_in[i]->set_data(ptr_in + 1);
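
The decode-path update above avoids shifting the whole cache: each step writes the newest k value at the tail slot of every head_dim row and then advances the tensor's base pointer by one element with set_data(ptr_in + 1), so the oldest column simply drops out of the view. A raw-buffer sketch of the same idea, with names and sizes of my own choosing:

```cpp
#include <cstdint>
#include <vector>

// Raw-buffer sketch of the shift-one-pointer k cache update above: append the
// new value at the end of each row, then slide the visible window right by one
// element instead of memmoving the whole cache.
int main() {
  const int32_t head_dim = 4;
  const int32_t kv_cache_len = 8;

  // The extra row ((head_dim + 1) * kv_cache_len) gives the slack needed to
  // keep sliding the window one element per decoded token.
  std::vector<uint8_t> storage((head_dim + 1) * kv_cache_len, 0);
  std::vector<uint8_t> new_k(head_dim, 7); // this step's k, one value per row

  uint8_t* ptr_in = storage.data(); // current base of the cache view
  for (int32_t j = 0, offset = kv_cache_len; j < head_dim;
       ++j, offset += kv_cache_len) {
    ptr_in[offset] = new_k[j]; // tail slot of row j, relative to the view
  }
  ptr_in += 1; // what set_data(ptr_in + 1) does to the tensor's data pointer
  (void)ptr_in;
  return 0;
}
```
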