Skip to content

Commit 6457918

Browse files
Aya-ZIbrarichardmcaihwu36
authored
Feature/add bottom causal mask (#2480)
* Rebase to latest * update * upd Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Update fmha_fusion.hpp * Update fmha_fusion.hpp fixed flipped logic for isQBegin * Update fmha_fusion.hpp * Avoid use of booleans The current expression is confusing * fmt * Update fmha_fusion.hpp Reproduce error/fix with: ./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend * add test, format --------- Co-authored-by: Richard Cai <ricai@nvidia.com> Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
1 parent b234a8c commit 6457918

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

examples/77_blackwell_fmha/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ set_property(
3939
)
4040

4141
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
42-
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
42+
set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
43+
set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
4344
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
4445
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
4546
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
@@ -119,7 +120,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
119120
77_blackwell_fmha.cu
120121
TEST_COMMAND_OPTIONS
121122
TEST_BASIC
122-
TEST_CAUSAL
123+
TEST_CAUSAL_00
124+
TEST_CAUSAL_01
123125
TEST_VARLEN
124126
TEST_HDIM64
125127
TEST_GQA
@@ -222,7 +224,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
222224
77_blackwell_mla_fwd.cu
223225
TEST_COMMAND_OPTIONS
224226
TEST_BASIC
225-
TEST_CAUSAL
227+
TEST_CAUSAL_00
226228
TEST_VARLEN
227229
TEST_HDIM64
228230
TEST_GQA

examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ struct CausalMask : NoMask {
225225
if constexpr (IsQBegin) {
226226
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
227227
} else {
228-
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
229-
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
228+
const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
229+
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
230230
}
231231
}
232232

0 commit comments

Comments
 (0)