@@ -382,7 +382,7 @@ void cpu_flash_attention(
       /* qk_sum */ qSplitSize +
       /* dst */ qSplitSize * headSize;

-  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
+  int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
   std::vector<char> buf_vec(size_bytes);
   void* buf = reinterpret_cast<void*>(buf_vec.data());
   // Need to double check the following
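A plausible reading of the new "* 4" factor is that the per-thread scratch space (qk, qk_max, qk_sum, dst) holds fp32 accumulators, so sizing it by query.element_size() under-allocates whenever the query dtype is narrower than 4 bytes. The sketch below only illustrates that sizing logic under this assumption; the helper name and the qk/qk_max terms of size_per_thread are not visible in the hunk and are assumptions, not the kernel's actual code.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the sizing above. The scratch is sized in
// units of the fp32 accumulator (sizeof(float) == 4), which plays the role
// of the "* 4" in the patch, independent of the query's element size.
// qSplitSize/kvSplitSize/headSize/num_thread stand in for the kernel's
// blocking parameters; the qk and qk_max terms are assumed, not shown in
// the hunk.
std::vector<char> alloc_flash_attention_scratch(
    int64_t qSplitSize,
    int64_t kvSplitSize,
    int64_t headSize,
    int64_t num_thread) {
  int64_t size_per_thread =
      /* qk     */ qSplitSize * kvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * headSize;
  int64_t size_bytes = size_per_thread * num_thread * sizeof(float);
  return std::vector<char>(size_bytes);
}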
@@ -452,6 +452,7 @@ void cpu_flash_attention(
        // However, lets just fix that as well.
        int64_t num_keys =
            is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
+       int64_t m_start_pos = m + start_pos;
        auto j_kv = j / num_reps;
        for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
          int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
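For orientation, a small standalone sketch of how num_keys and the new m_start_pos relate under causal masking; the concrete values are chosen for illustration and do not come from the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values: start_pos cached tokens, a query block starting at
  // row m of size qBlockSize, and kvSize keys overall.
  int64_t start_pos = 16, m = 8, qBlockSize = 4, kvSize = 64;
  bool is_causal = true;

  // With a causal mask, no query in this block can attend past its own
  // absolute position, so the key loop can stop early.
  int64_t num_keys =
      is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
  // Absolute position of the first query row in the block (the quantity the
  // patch names m_start_pos).
  int64_t m_start_pos = m + start_pos;

  std::printf("num_keys = %lld, m_start_pos = %lld\n",
              (long long)num_keys, (long long)m_start_pos);
  // Prints: num_keys = 28, m_start_pos = 24
}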
@@ -471,29 +472,62 @@ void cpu_flash_attention(
              static_cast<accum_t>(0),
              qk_data,
              kvBlockSize);
-          // Apply causal mask, fill unused, i.e. future values, with -inf
-          // Say you have q @ k.T size = [16, 32]
-          // With qblock size = 4, say you are processing
-          // q seq len dim = 8:11.
-          // Say kvSplitSize = 4
-          // Then for causal mask, the entries that needs to be
-          // ignored are
-          // [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31]
-          // Following condition says that num_keys = 8 + 4 =12
-          // (num_keys - n) <= kvSplitSize
-          // num_keys <= n + kvSplitSize
-          // If n + kvSplitSize is larger than 12, then some
-          // entries need masked out. In our example n = 4
-          // will qualify for that
-          if (is_causal && num_keys - n <= kvSplitSize) {
+          // There are 4 cases the causal mask has to cover when filling
+          // non-attendable positions with -inf:
+          /* 1. Everything is attended to. This happens when
+             m_start_pos > n + kvSplitSize, e.g. m_pos [8:15] and n_pos [0:7].
+             Since every query attends to all previous tokens, the block is
+             completely filled:
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             2. Not everything is attended to: only the first few rows of the
+             block fail to attend to every key position. This happens when
+             m_start_pos <= n + kvSplitSize but
+             m_start_pos + qBlockSize > n + kvSplitSize, e.g.
+             m_start_pos = 8, qBlockSize = 8, n = 4, kvSplitSize = 8,
+             i.e. m_pos [8:15] and n_pos [4:11]:
+             + + + + + - - -
+             + + + + + + - -
+             + + + + + + + -
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             3. Only the last few rows have anything to attend to. This
+             happens when m_start_pos < n and m_start_pos + qBlockSize >= n
+             and m_start_pos + qBlockSize <= n + kvSplitSize, e.g.
+             m_start_pos = 8, qBlockSize = 8, n = 13, kvSplitSize = 8,
+             i.e. m_pos [8:15] and n_pos [13:20]:
+             - - - - - - - -
+             - - - - - - - -
+             - - - - - - - -
+             - - - - - - - -
+             - - - - - - - -
+             + - - - - - - -
+             + + - - - - - -
+             + + + - - - - -
+             4. No row attends to anything. We do not have to handle this case
+             explicitly, because the loop
+             for (int64_t n = 0; n < num_keys; n += kvSplitSize)
+             exits before it can occur.
+          */
+          if (is_causal && m_start_pos <= n + kvSplitSize) {
             // For this fn to work k_split_size > q_split_size
-            for (int32_t row = 0; row < qBlockSize; ++row) {
-              int64_t last_col = m + (row + start_pos) - n;
+            for (int32_t row = 0;
+                 row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1));
+                 ++row) {
+              // last_col == 0 means the whole row is masked: m_pos is smaller
+              // than every n_pos in this block, so all of its keys lie in the
+              // future.
+              int64_t last_col =
+                  n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
               accum_t* row_ptr = qk_data + row * kvBlockSize;
               fill_stub(
-                  row_ptr + last_col + 1,
+                  row_ptr + last_col,
                   -std::numeric_limits<accum_t>::infinity(),
-                  kvBlockSize - last_col - 1);
+                  kvBlockSize - last_col);
             }
           }
           // Update attention weights with attention mask
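As a sanity check on the new guard and the last_col arithmetic, here is a small standalone sketch that reproduces the masking loop for the case-2 example from the comment (m_start_pos = 8, qBlockSize = 8, n = 4, kvSplitSize = kvBlockSize = 8) and prints '+' for kept entries and '-' for masked ones. It is only an illustration: std::fill_n stands in for fill_stub, and a plain float buffer stands in for qk_data.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  using accum_t = float;
  // Case-2 example from the comment: m_pos [8:15] vs n_pos [4:11].
  int64_t m_start_pos = 8, qBlockSize = 8, n = 4;
  int64_t kvSplitSize = 8, kvBlockSize = 8;

  // Stand-in for the qk scratch block, initialized to 0 ("attended").
  std::vector<accum_t> qk_data(qBlockSize * kvBlockSize, 0.0f);

  // Same guard and loop shape as the patch; std::fill_n plays the role of
  // fill_stub.
  if (m_start_pos <= n + kvSplitSize) {
    for (int32_t row = 0;
         row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1));
         ++row) {
      int64_t last_col =
          n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
      accum_t* row_ptr = qk_data.data() + row * kvBlockSize;
      std::fill_n(
          row_ptr + last_col,
          kvBlockSize - last_col,
          -std::numeric_limits<accum_t>::infinity());
    }
  }

  // Print the attendance grid; it should match the case-2 diagram above.
  for (int64_t r = 0; r < qBlockSize; ++r) {
    for (int64_t c = 0; c < kvBlockSize; ++c) {
      std::printf("%c ", qk_data[r * kvBlockSize + c] == 0.0f ? '+' : '-');
    }
    std::printf("\n");
  }
  return 0;
}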