@@ -356,6 +356,61 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     KernelHardwareInfo hw_info;
   };

+  // Helper function to calculate the number of previous K blocks that this
+  // block needs to wait for
+  template <class BlkCoord, class ProblemShape_>
+  CUTLASS_DEVICE int calculate_participating_k_blocks(
+      BlkCoord const& blk_coord,
+      ProblemShape_ const& problem_shape,
+      MainloopParams const& mainloop_params) {
+    auto
+        [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] =
+            blk_coord;
+
+    // For local attention, we need to calculate which K blocks actually
+    // participate. Due to the window properties, only early blocks can exit,
+    // so we can loop backwards and stop at the first non-participating block.
+    if constexpr (
+        std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask> ||
+        std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
+      auto [Q, K, D, D_VO, HB] = problem_shape;
+
+      int total_k_blocks = ceil_div(K, TileShapeK{});
+      int offset = 0;
+      if constexpr (std::is_base_of_v<
+                        cutlass::fmha::collective::LocalMask<false>,
+                        Mask>) {
+        offset = K - Q;
+      }
+
+      // Loop backwards to find the first non-participating block. This is
+      // efficient because participation is contiguous after the first
+      // participating block.
+      for (int k_blk = blk_coord_k - 1; k_blk >= 0; --k_blk) {
+        int k_max = (k_blk + 1) * TileShapeK{};
+        int q_max = min(Q, k_max - offset + mainloop_params.window_size_left);
+        int iter_end_for_k = ceil_div(q_max, TileShapeQ{});
+
+        int k_min = k_blk * TileShapeK{};
+        int q_min = max(0, k_min - offset - mainloop_params.window_size_right);
+        int iter_start_for_k = q_min / (int)TileShapeQ{};
+
+        if (iter_end_for_k <= iter_start_for_k) {
+          // Found the first non-participating block from the end:
+          // blocks 0 through k_blk don't participate,
+          // blocks k_blk+1 through blk_coord_k-1 do.
+          return blk_coord_k - 1 - k_blk;
+        }
+      }
+
+      // If we reach here, all previous blocks participate
+      return blk_coord_k;
+    } else {
+      // For causal, no mask, or residual mask, block x waits for x previous
+      // blocks
+      return blk_coord_k;
+    }
+  }

   static bool can_implement(Arguments const& args) {
     auto [Q, K, D, D_VO, HB] = args.problem_shape;
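The participation test in this helper is pure integer arithmetic, so it can be checked on the host. Below is a minimal standalone sketch of the same backwards scan, not the CUTLASS code itself: `kTileQ`/`kTileK`, `participating_k_blocks`, and the toy problem shape are invented stand-ins for `TileShapeQ{}`/`TileShapeK{}`, the helper above, and the kernel's `MainloopParams`; the `offset = K - Q` convention mirrors the `LocalMask<false>` branch.

```cpp
#include <algorithm>
#include <cstdio>

constexpr int kTileQ = 128; // stand-in for TileShapeQ{}
constexpr int kTileK = 128; // stand-in for TileShapeK{}

int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Host model of calculate_participating_k_blocks: count how many K blocks
// before blk_k touch at least one Q tile under a local attention window.
int participating_k_blocks(int blk_k, int Q, int offset,
                           int window_left, int window_right) {
  for (int k_blk = blk_k - 1; k_blk >= 0; --k_blk) {
    int k_max = (k_blk + 1) * kTileK;
    int q_max = std::min(Q, k_max - offset + window_left);
    int iter_end = ceil_div(q_max, kTileQ);

    int k_min = k_blk * kTileK;
    int q_min = std::max(0, k_min - offset - window_right);
    int iter_start = q_min / kTileQ;

    // An empty [iter_start, iter_end) range means this block, and every
    // earlier one, exits without ever incrementing the dQ semaphore.
    if (iter_end <= iter_start) {
      return blk_k - 1 - k_blk;
    }
  }
  return blk_k; // every earlier block participates
}

int main() {
  // Toy shape: Q = 256, K = 1024 (8 K blocks), bottom-right aligned
  // (offset = K - Q = 768), window of 128 on each side.
  for (int blk_k = 0; blk_k < 8; ++blk_k) {
    std::printf("K block %d waits for %d predecessor(s)\n", blk_k,
                participating_k_blocks(blk_k, /*Q=*/256, /*offset=*/768,
                                       /*window_left=*/128,
                                       /*window_right=*/128));
  }
}
```

With these numbers only K blocks 5 through 7 overlap any Q tile, so the sketch prints a wait target of 0 for blocks 0 through 5, 1 for block 6, and 2 for block 7, rather than the conservative targets that using `blk_coord_k` directly would give.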
@@ -367,6 +422,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     if (D % Alignment != 0 || D_VO % Alignment != 0) {
       return false;
     }
+
     return true;
   }

@@ -1498,23 +1554,26 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
         ? nullptr
         : (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);

-    // When IsDeterministic is true, we require each thread block to iterate
-    // over every K block. This ensures that the semaphore flag is incremented
-    // exactly K times, matching the block's K coordinate. This approach is
-    // conservative; we could optimize it by calculating the actual number of
-    // thread blocks participating in the reduction and adjusting the target
-    // value (blk_coord_k) accordingly.
+    // Calculate the actual number of participating K blocks for deterministic
+    // mode
+    int barrier_target = blk_coord_k; // Default for backward compatibility
+    if constexpr (IsDeterministic) {
+      barrier_target = calculate_participating_k_blocks(
+          blk_coord, problem_shape, mainloop_params);
+    }
+
     auto full_iter_count = IsDeterministic ? max_iter_count : iter_count;
     auto full_iter_index = 0;

     while (full_iter_count > 0) {
       if constexpr (IsDeterministic) {
-        // Wait until the semaphore flag become blk_coord_k
-        Barrier::wait_eq(
-            lock_ptr,
-            thread_idx,
-            full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
-            blk_coord_k);
+        // Wait until the semaphore flag reaches the actual number of
+        // participating blocks
+        Barrier::wait_eq(
+            lock_ptr,
+            thread_idx,
+            full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
+            barrier_target);
       }
       if (!IsDeterministic || (full_iter_index >= iter_start && full_iter_index < iter_end)) {
         pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
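To see why the wait target must count only participating predecessors, here is a toy host-side model of the semaphore protocol, collapsed to a single Q iteration. It is a sketch under stated assumptions, not the kernel's code: `std::atomic<int>` stands in for `cutlass::Barrier`'s wait/increment pair, one `double` stands in for a dQ accumulator tile, and `kBlocks`/`participates` are invented for the example.

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int kBlocks = 4;
  // Block 0 falls outside the local window and exits early, so later
  // blocks must not count it in their wait target.
  const bool participates[kBlocks] = {false, true, true, true};

  std::atomic<int> flag{0}; // the per-iteration dQ semaphore
  double dq = 0.0;          // one dQ accumulator tile

  std::vector<std::thread> blocks;
  for (int k = 0; k < kBlocks; ++k) {
    blocks.emplace_back([&, k] {
      if (!participates[k]) return; // early exit: never touches the flag
      // Participating predecessors, as calculate_participating_k_blocks
      // computes on the device.
      int target = 0;
      for (int p = 0; p < k; ++p) target += participates[p];
      while (flag.load(std::memory_order_acquire) != target) { /* spin */ }
      dq += 1.0; // reductions always happen in ascending k order
      flag.store(target + 1, std::memory_order_release);
    });
  }
  for (auto& t : blocks) t.join();
  std::printf("dq = %.1f, flag = %d\n", dq, flag.load());
}
```

Had block 1 waited for `flag == 1` (its raw K coordinate) while block 0 exited early, it would spin forever; the old code sidestepped this by forcing every block through every iteration, which the adjusted `barrier_target` makes unnecessary.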
@@ -1799,14 +1858,13 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
       int offset = get<1>(problem_shape) - get<0>(problem_shape);
       iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
-    }
-    else if constexpr (
+    } else if constexpr (
         std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask> ||
-        std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask>
-    ) {
-      int offset = std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>
-          ? get<1>(problem_shape) - get<0>(problem_shape)
-          : 0;
+        std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask>) {
+      int offset =
+          std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>
+          ? get<1>(problem_shape) - get<0>(problem_shape)
+          : 0;

       int k_max = (get<1>(blk_coord) + 1) * TileShapeK{};
       int q_max = min(get<0>(problem_shape), k_max - offset + params.mainloop_params.window_size_left);
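As a quick sanity check on these window bounds, the sketch below plugs invented numbers into the same formulas: a square problem (Q = K = 1024, so `offset` is 0 for either mask variant), 128x128 tiles, a 256/256 window, and K block 3. All names are local to the sketch.

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  int Q = 1024, TileQ = 128, TileK = 128;
  int window_left = 256, window_right = 256;
  int offset = 0; // K - Q for LocalMask<false>; zero here since Q == K
  int blk_k = 3;  // the K block whose Q-iteration range we bound

  int k_max = (blk_k + 1) * TileK;                        // 512
  int q_max = std::min(Q, k_max - offset + window_left);  // 768
  int iter_end = (q_max + TileQ - 1) / TileQ;             // ceil_div -> 6

  int k_min = blk_k * TileK;                              // 384
  int q_min = std::max(0, k_min - offset - window_right); // 128
  int iter_start = q_min / TileQ;                         // 1

  // K block 3 visits only Q tiles [1, 6) instead of all 8.
  std::printf("Q iterations [%d, %d)\n", iter_start, iter_end);
}
```

The trimmed `[iter_start, iter_end)` range is what the surrounding code hands to the mainloop, and it is the same interval test that `calculate_participating_k_blocks` evaluates for each predecessor block.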