Skip to content

Commit c403597

Browse files
committed
Refactor mask processor initialization to remove causal parameter in compute_attn functions
1 parent 03b24a3 commit c403597

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

csrc/flash_dmattn/src/flash_fwd_kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
395395

396396
FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
397397

398-
// Init dynamic mask processor
399-
FLASH_NAMESPACE::Mask<Is_causal> mask(
398+
// Init mask processor
399+
FLASH_NAMESPACE::Mask mask(
400400
binfo.actual_seqlen_k, binfo.actual_seqlen_q
401401
);
402402

@@ -1044,8 +1044,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
10441044

10451045
FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
10461046

1047-
// Init dynamic mask processor
1048-
FLASH_NAMESPACE::Mask<Is_causal> mask(
1047+
// Init mask processor
1048+
FLASH_NAMESPACE::Mask mask(
10491049
binfo.actual_seqlen_k, binfo.actual_seqlen_q
10501050
);
10511051

0 commit comments

Comments
 (0)