@@ -54,7 +54,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
window_size_right = max_seq_len_k.value();
}
}

if (causal) {
window_size_right = 0;
}
auto dispatch_fmha = [&](auto element,
auto element_out,
auto head_dim,
@@ -47,6 +47,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
window_size_right = max_seq_len_k.value();
}
}
if (causal) {
window_size_right = 0;
}

auto dispatch_fmha = [&](auto element,
auto element_out,
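Both the backward and forward dispatch hunks above express `causal` as a sliding window whose right half is zero: after `window_size_right` has been defaulted to `max_seq_len_k`, a causal mask clamps it back to 0, so a query may attend to keys at or before its own position and nothing beyond. A minimal sketch of the visibility rule these two parameters imply, assuming top-left alignment (no K - Q offset) and treating negative sizes as unbounded; `is_visible` is an illustrative helper, not the kernel's actual mask predicate:

```cpp
#include <cassert>

// Hedged sketch (not the kernel's mask predicate): visibility of key column k
// to query row q under a sliding window, assuming top-left alignment.
// A negative window size means "unbounded" on that side.
bool is_visible(int q, int k, int window_size_left, int window_size_right) {
  bool left_ok  = (window_size_left  < 0) || (k >= q - window_size_left);
  bool right_ok = (window_size_right < 0) || (k <= q + window_size_right);
  return left_ok && right_ok;
}

int main() {
  // Causal, i.e. window_size_right forced to 0: a query sees keys up to itself.
  assert( is_visible(/*q=*/5, /*k=*/5, /*left=*/-1, /*right=*/0));
  assert(!is_visible(/*q=*/5, /*k=*/6, /*left=*/-1, /*right=*/0));
  // window_size_right defaulted to max_seq_len_k (the non-causal path) never clips.
  assert( is_visible(/*q=*/5, /*k=*/6, /*left=*/-1, /*right=*/4096));
  return 0;
}
```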
@@ -354,62 +354,43 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
MainloopParams mainloop_params;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
// L2 cache swizzle parameters (calculated on host)
int kSwizzle;
int num_blocks_k;
int num_hb_quotient;
int num_hb_remainder;
cutlass::FastDivmod l2_major_divmod;
cutlass::FastDivmod l2_minor_divmod;
cutlass::FastDivmod head_divmod;
cutlass::FastDivmod l2_minor_residual_divmod;
};

// Helper function to calculate number of previous K blocks that this block
// needs to wait for
template <class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE int calculate_participating_k_blocks(
BlkCoord const& blk_coord,
template <class ProblemShape_>
CUTLASS_DEVICE int compute_expected_turn(
int iter_index,
int block_k,
ProblemShape_ const& problem_shape,
MainloopParams const& mainloop_params) {
auto
[blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] =
blk_coord;

// For local attention, we need to calculate which K blocks actually
// participate. Due to attention window properties, only early blocks can
// exit, so we can loop backwards and stop at first non-participating block.
// If mask is causal or local, reverse ordering of reduction
if constexpr (
std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
auto [Q, K, D, D_VO, HB] = problem_shape;

int total_k_blocks = ceil_div(K, TileShapeK{});
int offset = 0;
if constexpr (std::is_base_of_v<
cutlass::fmha::collective::LocalMask<false>,
Mask>) {
offset = K - Q;
}

// Loop backwards to find the first non-participating block
// This is efficient because participation is contiguous after the first
// participating block
for (int k_blk = blk_coord_k - 1; k_blk >= 0; --k_blk) {
int k_max = (k_blk + 1) * TileShapeK{};
int q_max = min(Q, k_max - offset + mainloop_params.window_size_left);
int iter_end_for_k = ceil_div(q_max, TileShapeQ{});

int k_min = k_blk * TileShapeK{};
int q_min = max(0, k_min - offset - mainloop_params.window_size_right);
int iter_start_for_k = q_min / (int)TileShapeQ{};

if (iter_end_for_k <= iter_start_for_k) {
// Found first non-participating block from the end
// Blocks 0 through k_blk don't participate
// Blocks k_blk+1 through blk_coord_k-1 participate
return blk_coord_k - 1 - k_blk;
if (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>
|| std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>){
offset = (get<1>(problem_shape) - get<0>(problem_shape));
}
}

// If we reach here, all previous blocks participate
return blk_coord_k;
} else {
// For causal, no mask or residual mask, block x waits for x previous
// blocks
return blk_coord_k;
}
int k_global_max = cute::ceil_div(get<1>(problem_shape) , TileShapeK{});
int k_max_for_q_block = std::min(
k_global_max,
cute::ceil_div((iter_index + 1) * TileShapeQ{} + offset + mainloop_params.window_size_right
, TileShapeK{}));
int last_k_block = k_max_for_q_block - 1;
return last_k_block - block_k;
}
return block_k;
}
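The replacement helper `compute_expected_turn` above drops the backward walk of `calculate_participating_k_blocks`: instead of counting participating predecessors, it derives, for a given Q block (`iter_index`) and the current K block, how many larger-indexed K blocks also touch that Q block and therefore write before it when the reduction order is reversed. A host-side sketch of the same arithmetic with plain integers (tile sizes, the K - Q offset and the window are ordinary parameters here; the kernel takes them from `TileShapeQ`/`TileShapeK`, the problem shape and the mask type):

```cpp
#include <algorithm>
#include <cstdio>

// Host-side sketch of the expected-turn computation for deterministic dQ
// accumulation. K blocks reduce in reverse order, so for a given Q block the
// largest participating K block writes first (turn 0), the next one gets
// turn 1, and so on.
int expected_turn(int iter_index, int block_k,
                  int K, int tile_q, int tile_k,
                  int offset, int window_size_right) {
  int k_global_max = (K + tile_k - 1) / tile_k;  // total number of K blocks
  // Last K block still visible from some row of Q block `iter_index`.
  int visible = (iter_index + 1) * tile_q + offset + window_size_right;
  int k_max_for_q_block = std::min(k_global_max, (visible + tile_k - 1) / tile_k);
  int last_k_block = k_max_for_q_block - 1;
  return last_k_block - block_k;  // number of K blocks that write before us
}

int main() {
  // Causal mask, Q = K = 512, 128x128 tiles, no offset: Q block 3 sees K
  // blocks 0..3, so K block 1 waits behind blocks 3 and 2, i.e. turn 2.
  printf("%d\n", expected_turn(/*iter_index=*/3, /*block_k=*/1,
                               /*K=*/512, /*tile_q=*/128, /*tile_k=*/128,
                               /*offset=*/0, /*window_size_right=*/0));
  return 0;
}
```

For mask types without reversed ordering the helper keeps the old behaviour and simply returns `block_k`, i.e. block x waits for x predecessors.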

static bool can_implement(Arguments const& args) {
@@ -464,6 +445,34 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
SmemLayoutDQ{}(_, _, _0{})
);

auto [H, B] = HB;
auto [H_R, H_K] = H;

long long const size_one_qdo_head =
long(K_) * long(D + D_VO) * long(sizeof(Element));
long long const size_one_dqaccum_head = long(K_) * long(D) * sizeof(float);
long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head;

int l2_cache_size = 0;
cudaDeviceGetAttribute(
&l2_cache_size, cudaDevAttrL2CacheSize, args.hw_info.device_id);
int const size_l2_reserved = static_cast<int>(l2_cache_size * 0.8);

auto find_log2_floor = [](int n) { return 31 - cutlass::clz(n); };
int const kSwizzle = size_l2_reserved < size_one_head
? 1
: (1 << find_log2_floor(size_l2_reserved / size_one_head));
int num_blocks_k = ceil_div(K_, TileShapeK{});
int total_heads_batches = H_K * B;
int num_hb_quotient = total_heads_batches / kSwizzle;
int num_hb_remainder = total_heads_batches % kSwizzle;

cutlass::FastDivmod l2_major_divmod(kSwizzle * num_blocks_k);
cutlass::FastDivmod l2_minor_divmod(kSwizzle);
cutlass::FastDivmod head_divmod(H_K);
cutlass::FastDivmod l2_minor_residual_divmod(
num_hb_remainder > 0 ? num_hb_remainder : 1);

return Params{
args.problem_shape,
args.mainloop,
args.mainloop.window_size_left,
args.mainloop.window_size_right
},
args.epilogue,
args.hw_info
};
args.epilogue,
args.hw_info,
kSwizzle,
num_blocks_k,
num_hb_quotient,
num_hb_remainder,
l2_major_divmod,
l2_minor_divmod,
head_divmod,
l2_minor_residual_divmod
};
}
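The new host-side block sizes the swizzle so that the per-head working set of `kSwizzle` consecutive (head, batch) pairs fits in roughly 80% of L2, rounded down to a power of two; `num_hb_quotient`/`num_hb_remainder` describe the tail group that is smaller than `kSwizzle`. A worked example of that sizing under assumed numbers (hypothetical shape, fp16 element type, 50 MiB of L2; the real code queries `cudaDevAttrL2CacheSize` on the target device):

```cpp
#include <cstdio>

// Worked example of the host-side swizzle sizing (all numbers hypothetical):
// K = 4096, D = D_VO = 128, fp16 elements, 50 MiB of L2.
int main() {
  long long K = 4096, D = 128, D_VO = 128;
  long long size_one_qdo_head     = K * (D + D_VO) * 2;  // fp16 tensors: 2 MiB
  long long size_one_dqaccum_head = K * D * 4;           // fp32 dQ accum: 2 MiB
  long long size_one_head = size_one_qdo_head + size_one_dqaccum_head;  // 4 MiB

  long long l2_cache = 50ll << 20;                  // assumed device L2
  long long reserved = (long long)(l2_cache * 0.8); // keep 80% for this working set

  // Largest power of two of per-head working sets that fits in the reserve,
  // i.e. 1 << floor(log2(reserved / size_one_head)).
  int swizzle = 1;
  while (2ll * swizzle * size_one_head <= reserved) swizzle *= 2;

  printf("kSwizzle = %d heads per L2-resident group\n", swizzle);  // prints 8
  return 0;
}
```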


@@ -1506,10 +1523,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state,
int max_iter_count,
int max_iter_end) {

typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state
) {
using X = Underscore;

auto [Q, K, D, D_VO, HB] = problem_shape;
@@ -1552,32 +1567,28 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
int *lock_ptr = !IsDeterministic
? nullptr
: (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);

// Calculate the actual number of participating K blocks for deterministic
// mode
int barrier_target = blk_coord_k; // Default for backward compatibility
if constexpr (IsDeterministic) {
barrier_target = calculate_participating_k_blocks(
blk_coord, problem_shape, mainloop_params);
}

auto full_iter_count = IsDeterministic ? max_iter_count : iter_count;
auto full_iter_index = 0;
? nullptr
: (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);

int expected_turn = 0;
// Optimized: Only iterate over Q blocks this K block actually processes
while (iter_count > 0) {

while (full_iter_count > 0) {
if constexpr (IsDeterministic) {
// Wait until the semaphore flag reaches the actual number of
// participating blocks
expected_turn = compute_expected_turn(
iter_index,
blk_coord_k,
problem_shape,
mainloop_params);
Barrier::wait_eq(
lock_ptr,
thread_idx,
full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
barrier_target);
lock_ptr,
thread_idx,
iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
expected_turn);
}
if (!IsDeterministic || (full_iter_index >= iter_start && full_iter_index < iter_end)) {
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
{
pipeline_mma_reduce_dq.consumer_wait(
pipeline_mma_reduce_dq_consumer_state);

Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));

Expand Down Expand Up @@ -1609,41 +1620,38 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
// TMA REDUCE ADD - atomic operation to global memory
copy(
mainloop_params.tma_red_dq,
tDQsDQ(
_,
_,
_0{},
pipeline_reduce_tma_store_producer_state.index()),
tDQgDQ(_, _, i, iter_index, blk_coord_batch));
/// tma_store_arrive();
pipeline_reduce_tma_store.producer_commit(
pipeline_reduce_tma_store_producer_state);
}

++pipeline_reduce_tma_store_producer_state;
}

// Update iter index
iter_index += 1;
}

}
if constexpr (IsDeterministic) {
// Increment the semaphore flag
Barrier::arrive_inc(
lock_ptr,
lock_ptr,
thread_idx,
full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch));

full_iter_index += 1;
iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch));

if (full_iter_index == max_iter_end) {
iter_index = iter_start;
full_iter_index = 0;
get<0,0>(blk_coord_batch) += 1;
}
}
else {
if (iter_index == iter_end) {
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
}
}

full_iter_count -= 1;
iter_count -= 1;
}
}
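In deterministic mode the loop above serializes the dQ reduce-adds for each Q block: a reducer waits on the per-(Q block, head, batch) semaphore until the counter equals its expected turn, issues the TMA reduce-add, then increments the counter. A simplified CUDA illustration of that turn-taking with a plain atomic counter, assuming a grid small enough that all blocks are co-resident; `ordered_reduce` is a stand-in for the kernel's `GenericBarrier` wait/arrive plus TMA reduce-add, not the actual implementation:

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Each block owning the same output tile spins until the shared semaphore
// equals its turn, adds its partial result, then passes the turn on. The real
// kernel derives each block's turn with compute_expected_turn().
__global__ void ordered_reduce(float* out, const float* partials,
                               int* semaphore, int n) {
  int my_turn = blockIdx.x;  // here: plain launch order, for illustration only

  if (threadIdx.x == 0) {
    // Spin until it is our turn (atomicAdd(ptr, 0) acts as an atomic read).
    while (atomicAdd(semaphore, 0) != my_turn) { }
  }
  __syncthreads();

  // Add this block's partial contribution in a fixed, reproducible order.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    out[i] += partials[blockIdx.x * n + i];
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    __threadfence();          // make our updates visible before releasing
    atomicAdd(semaphore, 1);  // hand the turn to the next block
  }
}

int main() {
  const int n = 1024, blocks = 4, threads = 128;
  float *out, *partials; int *sem;
  cudaMalloc(&out, n * sizeof(float));
  cudaMalloc(&partials, (size_t)blocks * n * sizeof(float));
  cudaMalloc(&sem, sizeof(int));
  cudaMemset(out, 0, n * sizeof(float));
  cudaMemset(partials, 0, (size_t)blocks * n * sizeof(float));
  cudaMemset(sem, 0, sizeof(int));
  ordered_reduce<<<blocks, threads>>>(out, partials, sem, n);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}
```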

@@ -1850,14 +1858,51 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

pipeline_init_wait(size(ClusterShape{}));

auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z));
// Head swizzling: Decompose tile_idx to (block_k, head_idx, batch_idx)

int tile_idx = blockIdx.x;
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;

// Use pre-calculated swizzle parameters from host
int const kSwizzle = params.kSwizzle;
int num_blocks_k = params.num_blocks_k;
int num_hb_quotient = params.num_hb_quotient;

// Step 1: Which section (bidhb) and position within section (l2_mod)
int bidhb, l2_mod;
bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);

// Step 2: Within section, get block_k and head-batch offset
int block_k, bidhb_residual;

if (bidhb < num_hb_quotient) {
block_k = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
} else {
block_k = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
}
// Step 3: Convert to actual head and batch indices
int head_batch_idx = bidhb * kSwizzle + bidhb_residual;
int batch_idx, head_idx;
batch_idx = params.head_divmod.divmod(head_idx, head_batch_idx);

if constexpr (
std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
// Reverse block_k ordering (for SPT scheduling)
block_k = num_blocks_k - 1 - block_k;
}

auto blk_coord = make_coord(_0{}, block_k, _0{}, _0{}, make_coord(make_coord(0, head_idx), batch_idx));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
int max_iter_end = IsDeterministic ? iter_end : 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
}

int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);
int max_iter_count = IsDeterministic ? max_iter_end * get<4,0,0>(problem_shape) : 0;

if (iter_count <= 0) {
epilogue_clear(
@@ -1987,13 +2031,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state,
max_iter_count, max_iter_end
);

pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
}
pipeline_mma_reduce_dq,
pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store,
pipeline_reduce_tma_store_producer_state);

pipeline_reduce_tma_store.producer_tail(
pipeline_reduce_tma_store_producer_state);
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();

auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(ceil_div(K, TileShapeK{}), H_K, B);

int num_blocks_k = ceil_div(K, TileShapeK{});
int total_heads_batches = H_K * B;
int total_tiles = num_blocks_k * total_heads_batches;

dim3 grid(total_tiles, 1, 1);
return grid;
}
};
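With the grid flattened to a single dimension, the kernel recovers `(block_k, head_idx, batch_idx)` from `blockIdx.x` itself: tiles are grouped so that all K blocks of `kSwizzle` consecutive (head, batch) pairs are issued together, and for causal or local masks `block_k` is then mirrored to match the reversed reduction order. A host-side sketch of the same decomposition using ordinary division in place of `cutlass::FastDivmod` (the shape numbers in `main` are hypothetical):

```cpp
#include <cstdio>

// Sketch of the blockIdx.x -> (block_k, head_idx, batch_idx) mapping used by
// the L2 swizzle, with plain integer div/mod standing in for FastDivmod.
void decode_tile(int tile_idx, int kSwizzle, int num_blocks_k, int H_K, int B,
                 int* block_k, int* head_idx, int* batch_idx) {
  int total_hb = H_K * B;
  int num_hb_quotient  = total_hb / kSwizzle;  // full groups of kSwizzle pairs
  int num_hb_remainder = total_hb % kSwizzle;  // size of the tail group

  // Step 1: which group of (head, batch) pairs, and the offset inside it.
  int bidhb  = tile_idx / (kSwizzle * num_blocks_k);
  int l2_mod = tile_idx % (kSwizzle * num_blocks_k);

  // Step 2: split the offset into a K-block index and a head-batch offset,
  // using the smaller divisor for the tail group.
  int group = (bidhb < num_hb_quotient)
                  ? kSwizzle
                  : (num_hb_remainder > 0 ? num_hb_remainder : 1);
  int bk             = l2_mod / group;
  int bidhb_residual = l2_mod % group;

  // Step 3: recover head and batch from the flattened head-batch index.
  int head_batch_idx = bidhb * kSwizzle + bidhb_residual;
  *batch_idx = head_batch_idx / H_K;
  *head_idx  = head_batch_idx % H_K;
  *block_k   = bk;  // causal/local masks additionally mirror this index
}

int main() {
  int bk, h, b;
  // Hypothetical shape: 4 K blocks, 2 KV heads, 3 batches, kSwizzle = 2.
  for (int t = 0; t < 4 * 2 * 3; ++t) {
    decode_tile(t, /*kSwizzle=*/2, /*num_blocks_k=*/4, /*H_K=*/2, /*B=*/3,
                &bk, &h, &b);
    printf("tile %2d -> block_k=%d head=%d batch=%d\n", t, bk, h, b);
  }
  return 0;
}
```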