
Commit 01c1880

henrylhtsang authored and meta-codesync[bot] committed
bwd case when offset is 0 (#4963)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1982

Pull Request resolved: #4963

This adds the backward case for when we are not using the bottom-right mask, i.e. when the offset is 0. Perf should be slightly better in that case.

# notes
Backward is in general not stable. It can sometimes hit an IMA (illegal memory access), and the numerics are not as good as we want them to be.

Reviewed By: q10

Differential Revision: D83076701

fbshipit-source-id: a1016b15a86d10f21d962166eae036d959befe18
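To make the "offset is 0" case concrete, here is a minimal sketch of how the mask diagonal moves between the two alignments. The `Alignment` enum and both function names are hypothetical, and the visibility rule `k_idx <= q_idx + offset` is the commonly used causal convention, assumed here rather than taken from this commit.

```cpp
#include <cassert>

// Hypothetical names: the kernel encodes this choice in the mask type instead.
enum class Alignment { TopLeft, BottomRight };

// Diagonal offset: K length minus Q length for bottom-right alignment,
// 0 otherwise (the case this commit adds to the backward pass).
constexpr int diagonal_offset(Alignment a, int seqlen_q, int seqlen_k) {
  return a == Alignment::BottomRight ? seqlen_k - seqlen_q : 0;
}

// Assumed causal convention: key k_idx is visible to query q_idx
// when it lies on or to the left of the shifted diagonal.
inline bool causal_visible(int q_idx, int k_idx, int offset) {
  assert(q_idx >= 0 && k_idx >= 0);
  return k_idx <= q_idx + offset;
}

// With equal sequence lengths the two alignments coincide.
static_assert(diagonal_offset(Alignment::BottomRight, 4, 4) == 0,
              "equal lengths give a zero offset");
```

With 2 queries and 5 keys, bottom-right alignment gives an offset of 3, so query 0 sees keys 0 through 3. With offset 0, query 0 sees only key 0.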
1 parent 4a88713 commit 01c1880

1 file changed: +8 -2 lines changed


fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp

@@ -1799,8 +1799,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
       } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
         int offset = get<1>(problem_shape) - get<0>(problem_shape);
         iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
-      } else if constexpr (std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
-        int offset = get<1>(problem_shape) - get<0>(problem_shape);
+      }
+      else if constexpr (
+        std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask> ||
+        std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask>
+      ) {
+        int offset = std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>
+          ? get<1>(problem_shape) - get<0>(problem_shape)
+          : 0;

         int k_max = (get<1>(blk_coord) + 1) * TileShapeK{};
         int q_max = min(get<0>(problem_shape), k_max - offset + params.mainloop_params.window_size_left);
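The changed branch can be read in isolation as the sketch below. It is a host-side approximation, not the kernel code: `LocalMaskSketch`, `DerivedBottomRightMask`, and `q_max_for_k_block` are hypothetical stand-ins, plain `int` arguments replace `problem_shape`, `blk_coord`, and `TileShapeK`, and it assumes `get<0>(problem_shape)` is the Q length and `get<1>(problem_shape)` the K length, as the `q_max` line suggests.

```cpp
#include <algorithm>
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for cutlass::fmha::collective::LocalMask<bool>.
// The diff only shows that `false` takes the K-minus-Q offset and `true` takes 0.
template <bool kFlag>
struct LocalMaskSketch {};

// A derived mask, to show why the check uses is_base_of_v rather than is_same_v.
struct DerivedBottomRightMask : LocalMaskSketch<false> {};

template <class Mask>
int q_max_for_k_block(int k_block, int tile_k, int seqlen_q, int seqlen_k,
                      int window_size_left) {
  static_assert(std::is_base_of_v<LocalMaskSketch<false>, Mask> ||
                    std::is_base_of_v<LocalMaskSketch<true>, Mask>,
                "Mask must derive from one of the local mask variants");
  assert(tile_k > 0);

  // Same selection as the ternary in the diff, resolved at compile time.
  constexpr bool kUsesOffset = std::is_base_of_v<LocalMaskSketch<false>, Mask>;
  int offset = kUsesOffset ? seqlen_k - seqlen_q : 0;

  // Mirrors the k_max / q_max lines that follow the changed branch.
  int k_max = (k_block + 1) * tile_k;
  return std::min(seqlen_q, k_max - offset + window_size_left);
}
```

For example, with a Q length of 256, a K length of 512, a K tile of 128, and a left window of 32, K block 2 yields a bound of 160 for the `false` variant (offset 256) and 256 for the `true` variant (offset 0, clamped to the Q length).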
