Skip to content

Commit c53203f

Browse files
Aya-ZIbra authored and meta-codesync[bot] committed
Fix Deterministic + Left window mask (#4982)
Summary:
Pull Request resolved: #4982
X-link: https://github.com/facebookresearch/FBGEMM/pull/1995

See the write-up for more details: https://www.internalfb.com/code/fbsource/[D84054907-V2]/blackwell_fmha_backward_deadlock_analysis.md

**Before the fix** (causing deadlock):
```
Launched blocks:    k=0  k=1  k=2  k=3  k=4
Participating:      NO   YES  YES  YES  YES
Semaphore targets:  0    1    2    3    4
Wait chain:         -    0→1  1→2  2→3  3→4
                         ↑
                         ↑
                         HANGS (waiting for k=0 that never signals)
```

**After the fix**:
```
Launched blocks:     k=0  k=1      k=2      k=3      k=4
Participating:       NO   YES      YES      YES      YES
Calculated targets:  -    0        1        2        3
Wait chain:          -    wait(0)  wait(1)  wait(2)  wait(3)
Execution order:          1st      2nd      3rd      4th
```

Reviewed By: jduprat
Differential Revision: D84012726
fbshipit-source-id: d70a899d0e0012aa275167a0653b8aafcafc3396
1 parent 2cf14c5 commit c53203f

File tree

1 file changed

+77
-19
lines changed

1 file changed

+77
-19
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,61 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
356356
KernelHardwareInfo hw_info;
357357
};
358358

359+
// Helper function to calculate number of previous K blocks that this block
360+
// needs to wait for
361+
template <class BlkCoord, class ProblemShape_>
362+
CUTLASS_DEVICE int calculate_participating_k_blocks(
363+
BlkCoord const& blk_coord,
364+
ProblemShape_ const& problem_shape,
365+
MainloopParams const& mainloop_params) {
366+
auto
367+
[blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] =
368+
blk_coord;
369+
370+
// For local attention, we need to calculate which K blocks actually
371+
// participate. Due to attention window properties, only early blocks can
372+
// exit, so we can loop backwards and stop at first non-participating block.
373+
if constexpr (
374+
std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask> ||
375+
std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
376+
auto [Q, K, D, D_VO, HB] = problem_shape;
377+
378+
int total_k_blocks = ceil_div(K, TileShapeK{});
379+
int offset = 0;
380+
if constexpr (std::is_base_of_v<
381+
cutlass::fmha::collective::LocalMask<false>,
382+
Mask>) {
383+
offset = K - Q;
384+
}
385+
386+
// Loop backwards to find the first non-participating block
387+
// This is efficient because participation is contiguous after the first
388+
// participating block
389+
for (int k_blk = blk_coord_k - 1; k_blk >= 0; --k_blk) {
390+
int k_max = (k_blk + 1) * TileShapeK{};
391+
int q_max = min(Q, k_max - offset + mainloop_params.window_size_left);
392+
int iter_end_for_k = ceil_div(q_max, TileShapeQ{});
393+
394+
int k_min = k_blk * TileShapeK{};
395+
int q_min = max(0, k_min - offset - mainloop_params.window_size_right);
396+
int iter_start_for_k = q_min / (int)TileShapeQ{};
397+
398+
if (iter_end_for_k <= iter_start_for_k) {
399+
// Found first non-participating block from the end
400+
// Blocks 0 through k_blk don't participate
401+
// Blocks k_blk+1 through blk_coord_k-1 participate
402+
return blk_coord_k - 1 - k_blk;
403+
}
404+
}
405+
406+
// If we reach here, all previous blocks participate
407+
return blk_coord_k;
408+
} else {
409+
// For causal, no mask or residual mask, block x waits for x previous
410+
// blocks
411+
return blk_coord_k;
412+
}
413+
}
359414

360415
static bool can_implement(Arguments const& args) {
361416
auto [Q, K, D, D_VO, HB] = args.problem_shape;
@@ -367,6 +422,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
367422
if (D % Alignment != 0 || D_VO % Alignment != 0) {
368423
return false;
369424
}
425+
370426
return true;
371427
}
372428

@@ -1498,23 +1554,26 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
14981554
? nullptr
14991555
: (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);
15001556

1501-
// When IsDeterministic is true, we require each thread block to iterate
1502-
// over every K block. This ensures that the semaphore flag is incremented
1503-
// exactly K times, matching the block's K coordinate. This approach is
1504-
// conservative; we could optimize it by calculating the actual number of
1505-
// thread blocks participating in the reduction and adjusting the target
1506-
// value (blk_coord_k) accordingly.
1557+
// Calculate the actual number of participating K blocks for deterministic
1558+
// mode
1559+
int barrier_target = blk_coord_k; // Default for backward compatibility
1560+
if constexpr (IsDeterministic) {
1561+
barrier_target = calculate_participating_k_blocks(
1562+
blk_coord, problem_shape, mainloop_params);
1563+
}
1564+
15071565
auto full_iter_count = IsDeterministic ? max_iter_count : iter_count;
15081566
auto full_iter_index = 0;
15091567

15101568
while (full_iter_count > 0) {
15111569
if constexpr (IsDeterministic) {
1512-
// Wait until the semaphore flag become blk_coord_k
1513-
Barrier::wait_eq(
1514-
lock_ptr,
1515-
thread_idx,
1516-
full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
1517-
blk_coord_k);
1570+
// Wait until the semaphore flag reaches the actual number of
1571+
// participating blocks
1572+
Barrier::wait_eq(
1573+
lock_ptr,
1574+
thread_idx,
1575+
full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
1576+
barrier_target);
15181577
}
15191578
if (!IsDeterministic || (full_iter_index >= iter_start && full_iter_index < iter_end)) {
15201579
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
@@ -1799,14 +1858,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
17991858
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
18001859
int offset = get<1>(problem_shape) - get<0>(problem_shape);
18011860
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
1802-
}
1803-
else if constexpr (
1861+
} else if constexpr (
18041862
std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask> ||
1805-
std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask>
1806-
) {
1807-
int offset = std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>
1808-
? get<1>(problem_shape) - get<0>(problem_shape)
1809-
: 0;
1863+
std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask>) {
1864+
int offset =
1865+
std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>
1866+
? get<1>(problem_shape) - get<0>(problem_shape)
1867+
: 0;
18101868

18111869
int k_max = (get<1>(blk_coord) + 1) * TileShapeK{};
18121870
int q_max = min(get<0>(problem_shape), k_max - offset + params.mainloop_params.window_size_left);

0 commit comments

Comments
 (0)