Skip to content

Commit e9f9fcc

Browse files
committed
Separates mask and bias copy operations into distinct types
Splits the combined GmemTiledCopyMaskBias type into separate GmemTiledCopyMask and GmemTiledCopyBias types in both forward and backward kernel traits. This separation improves code clarity and allows for independent handling of mask and bias copy operations, enabling more flexible memory access patterns and potential optimizations.
1 parent 3b7b57b commit e9f9fcc

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

csrc/flash_dmattn/src/kernel_traits.h

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,14 @@ struct Flash_fwd_kernel_traits : public Base {
183183
Layout<Shape<_1, _8>>{}
184184
)
185185
); // Val layout, 8 vals per read
186-
using GmemTiledCopyMaskBias = decltype(
186+
using GmemTiledCopyMask = decltype(
187+
make_tiled_copy(
188+
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
189+
GmemLayoutAtom{},
190+
Layout<Shape<_1, _8>>{}
191+
)
192+
); // Val layout, 8 vals per read
193+
using GmemTiledCopyBias = decltype(
187194
make_tiled_copy(
188195
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
189196
GmemLayoutAtom{},
@@ -442,7 +449,14 @@ struct Flash_bwd_kernel_traits : public Base {
442449
Layout<Shape < _1, _8>>{}
443450
)
444451
); // Val layout, 8 vals per store
445-
using GmemTiledCopyMaskBias = decltype(
452+
using GmemTiledCopyMask = decltype(
453+
make_tiled_copy(
454+
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
455+
GmemLayoutAtom{},
456+
Layout<Shape<_1, _8>>{}
457+
)
458+
); // Val layout, 8 vals per read
459+
using GmemTiledCopyBias = decltype(
446460
make_tiled_copy(
447461
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
448462
GmemLayoutAtom{},

0 commit comments

Comments
 (0)