Commit 510aaf5

Adds mask reduction utility function
Implements a device function that performs a logical OR reduction across the elements of a mask tensor fragment and broadcasts the result to all threads in the thread block via the `__syncthreads_or` primitive. Enables efficient sparse attention pattern processing by letting the threads of a block collectively determine whether any mask elements are active within a given region.
1 parent e23b08f commit 510aaf5

File tree

1 file changed: +19 −0 lines changed


csrc/flash_dmattn/src/utils.h

Lines changed: 19 additions & 0 deletions
@@ -333,6 +333,25 @@ __forceinline__ __device__ void sparse_gemm_rs(
     }
 }
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+template <typename Tensor, typename ThrCopy>
+__forceinline__ __device__ void mask_or_reduce(
+    Tensor &tSsMask,
+    bool &active,
+    ThrCopy smem_thr_copy_Mask
+) {
+    Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_D(tSsMask);
+    bool active_local = false;
+    #pragma unroll
+    for (int i = 0; i < size(tSsMask_copy_view); ++i) {
+        active_local |= tSsMask_copy_view(i);
+    }
+    active = __syncthreads_or(active_local);
+}
+
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
