@@ -382,7 +382,7 @@ void cpu_flash_attention(
       /* qk_sum */ qSplitSize +
       /* dst */ qSplitSize * headSize;
 
-  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
+  int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
   std::vector<char> buf_vec(size_bytes);
   void* buf = reinterpret_cast<void*>(buf_vec.data());
   // Need to double check the following
@@ -452,6 +452,7 @@ void cpu_flash_attention(
         // However, let's just fix that as well.
         int64_t num_keys =
             is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
+        int64_t m_start_pos = m + start_pos;
         auto j_kv = j / num_reps;
         for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
           int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
@@ -471,29 +472,58 @@ void cpu_flash_attention(
               static_cast<accum_t>(0),
               qk_data,
               kvBlockSize);
-          // Apply causal mask, fill unused, i.e. future values, with -inf
-          // Say you have q @ k.T size = [16, 32]
-          // With qblock size = 4, say you are processing
-          // q seq len dim = 8:11.
-          // Say kvSplitSize = 4
-          // Then for causal mask, the entries that needs to be
-          // ignored are
-          // [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31]
-          // Following condition says that num_keys = 8 + 4 =12
-          // (num_keys - n) <= kvSplitSize
-          // num_keys <= n + kvSplitSize
-          // If n + kvSplitSize is larger than 12, then some
-          // entries need masked out. In our example n = 4
-          // will qualify for that
-          if (is_causal && num_keys - n <= kvSplitSize) {
+          // There are 4 cases that is_causal has to cover to fill non-attendable positions with -inf.
+          /* 1. Everything is attended to. This happens when m_start_pos > n + kvSplitSize,
+             e.g. m_pos [8:15] and n_pos [0:7]. Since you must attend to all previous tokens,
+             the matrix is full:
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             2. Not everything is attended to; only some query tokens at the beginning don't attend
+             to every key. This happens when m_start_pos <= n + kvSplitSize but m_start_pos + qBlockSize > n + kvSplitSize,
+             e.g. m_start_pos = 8, qBlockSize = 8,
+             n = 4, kvSplitSize = 8.
+             For example m_pos is [8:15] but n_pos is [4:11]:
+             + + + + + - - -
+             + + + + + + - -
+             + + + + + + + -
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             + + + + + + + +
+             3. In this case only the last few query tokens have something to attend to. This happens when m_start_pos < n
+             and m_start_pos + qBlockSize >= n
+             and m_start_pos + qBlockSize <= n + kvSplitSize,
+             e.g. m_start_pos = 8, qBlockSize = 8,
+             n = 13, kvSplitSize = 8.
+             For example m_pos is [8:15] but n_pos is [13:20]:
+             - - - - - - - -
+             - - - - - - - -
+             - - - - - - - -
+             - - - - - - - -
+             - - - - - - - -
+             + - - - - - - -
+             + + - - - - - -
+             + + + - - - - -
+             4. In this case no query token attends to anything, but we don't have to handle it because
+             the loop for (int64_t n = 0; n < num_keys; n += kvSplitSize) will exit before that.
+          */
+          if (is_causal && m_start_pos <= n + kvSplitSize) {
             // For this fn to work k_split_size > q_split_size
-            for (int32_t row = 0; row < qBlockSize; ++row) {
-              int64_t last_col = m + (row + start_pos) - n;
+            for (int32_t row = 0; row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1)); ++row) {
+              // When last_col is 0, it means that the entire row is not attended to,
+              // because m_pos is smaller than n_pos, i.e. every key in this block is in the future.
+              int64_t last_col = n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
               accum_t* row_ptr = qk_data + row * kvBlockSize;
               fill_stub(
-                  row_ptr + last_col + 1,
+                  row_ptr + last_col,
                   -std::numeric_limits<accum_t>::infinity(),
-                  kvBlockSize - last_col - 1);
+                  kvBlockSize - last_col);
             }
           }
           // Update attention weights with attention mask
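
To make the case analysis above easy to verify outside the kernel, here is a minimal standalone sketch of the same per-row masking arithmetic. It is an illustration under stated assumptions, not the ExecuTorch kernel itself: the helper name mask_block is made up, fill_stub is replaced by std::fill, accum_t is hard-coded to float, and kvBlockSize is simply taken equal to kvSplitSize. Running it with m_start_pos = 8 and n = 0, 4, 13 should print the attend/mask patterns of cases 1, 2, and 3 from the comment.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Stand-in for the kernel's accumulation type.
using accum_t = float;

// Toy re-implementation of the patched causal fill: rows whose query position
// cannot see some key columns of this kv block get that tail set to -inf.
void mask_block(
    std::vector<accum_t>& qk, // qBlockSize x kvBlockSize scores, row-major
    int64_t m_start_pos,      // absolute position of the first query row
    int64_t qBlockSize,
    int64_t n,                // absolute position of the first key column
    int64_t kvSplitSize,
    int64_t kvBlockSize) {
  if (m_start_pos <= n + kvSplitSize) {
    for (int32_t row = 0;
         row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1));
         ++row) {
      // last_col == 0 means the whole row lies before this kv block.
      int64_t last_col =
          n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
      accum_t* row_ptr = qk.data() + row * kvBlockSize;
      std::fill(
          row_ptr + last_col,
          row_ptr + kvBlockSize,
          -std::numeric_limits<accum_t>::infinity());
    }
  }
}

int main() {
  const int64_t qBlockSize = 8, kvSplitSize = 8, kvBlockSize = 8;
  // n = 0, 4 and 13 reproduce cases 1, 2 and 3 from the comment above.
  for (int64_t n : {0, 4, 13}) {
    std::vector<accum_t> qk(qBlockSize * kvBlockSize, 0.0f);
    mask_block(qk, /*m_start_pos=*/8, qBlockSize, n, kvSplitSize, kvBlockSize);
    std::printf("n = %lld\n", static_cast<long long>(n));
    for (int64_t r = 0; r < qBlockSize; ++r) {
      for (int64_t c = 0; c < kvBlockSize; ++c) {
        std::printf("%c ", std::isinf(qk[r * kvBlockSize + c]) ? '-' : '+');
      }
      std::printf("\n");
    }
  }
  return 0;
}

The loop bound m_start_pos + row < n + (kvSplitSize - 1) is what keeps fully attended rows untouched: once a query position can already see the last key column of the block, there is nothing left to fill, which is why the case-1 example (m_start_pos = 8, n = 0) is never modified even though it satisfies the outer if condition.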