Skip to content

Commit 69df087

Browse files
committed
Refactors layout definitions for Gmem in Flash kernel traits for improved readability
1 parent 3d88f28 commit 69df087

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

csrc/flash_dmattn/src/kernel_traits.h

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -229,15 +229,15 @@ struct Flash_fwd_kernel_traits : public Base {
229229
// Accumulator layout for output
230230
using GmemLayoutAtomOaccum = std::conditional_t<
231231
kBlockKSmem == 32,
232-
Layout<Shape <_16, _8>, Stride< _8, _1>>, // Thread layout, 8 threads per row
233-
Layout<Shape <_8, _16>, Stride< _16, _1>> // Thread layout, 16 threads per row
232+
Layout<Shape<_16, _8>, Stride<_8, _1>>, // Thread layout, 8 threads per row
233+
Layout<Shape<_8, _16>, Stride<_16, _1>> // Thread layout, 16 threads per row
234234
>;
235235

236236
using GmemTiledCopyOaccum = decltype(
237237
make_tiled_copy(
238238
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
239239
GmemLayoutAtomOaccum{},
240-
Layout<Shape < _1, _4>>{}
240+
Layout<Shape<_1, _4>>{}
241241
)
242242
); // Val layout, 4 vals per store
243243
};
@@ -442,15 +442,15 @@ struct Flash_bwd_kernel_traits : public Base {
442442
static_assert(kNThreads % kGmemThreadsPerRowMask == 0, "kNThreads must be a multiple of kGmemThreadsPerRowMask");
443443
static_assert(kNThreads % kGmemThreadsPerRowBias == 0, "kNThreads must be a multiple of kGmemThreadsPerRowBias");
444444
using GmemLayoutAtomQKVO = Layout<
445-
Shape <Int<kNThreads / kGmemThreadsPerRowQKVO>, Int<kGmemThreadsPerRowQKVO>>,
445+
Shape<Int<kNThreads / kGmemThreadsPerRowQKVO>, Int<kGmemThreadsPerRowQKVO>>,
446446
Stride<Int<kGmemThreadsPerRowQKVO>, _1>
447447
>;
448448
using GmemLayoutAtomMask = Layout<
449-
Shape <Int<kNThreads / kGmemThreadsPerRowMask>, Int<kGmemThreadsPerRowMask>>,
449+
Shape<Int<kNThreads / kGmemThreadsPerRowMask>, Int<kGmemThreadsPerRowMask>>,
450450
Stride<Int<kGmemThreadsPerRowMask>, _1>
451451
>;
452452
using GmemLayoutAtomBias = Layout<
453-
Shape <Int<kNThreads / kGmemThreadsPerRowBias>, Int<kGmemThreadsPerRowBias>>,
453+
Shape<Int<kNThreads / kGmemThreadsPerRowBias>, Int<kGmemThreadsPerRowBias>>,
454454
Stride<Int<kGmemThreadsPerRowBias>, _1>
455455
>;
456456

@@ -486,7 +486,7 @@ struct Flash_bwd_kernel_traits : public Base {
486486
make_tiled_copy(
487487
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
488488
GmemLayoutAtomQKVO{},
489-
Layout<Shape < _1, _8>>{}
489+
Layout<Shape<_1, _8>>{}
490490
)
491491
); // Val layout, 8 vals per store
492492
using GmemTiledCopydBias = decltype(
@@ -500,35 +500,35 @@ struct Flash_bwd_kernel_traits : public Base {
500500
make_tiled_copy(
501501
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
502502
GmemLayoutAtomQKVO{},
503-
Layout<Shape < _1, _8>>{}
503+
Layout<Shape<_1, _8>>{}
504504
)
505505
); // Val layout, 8 vals per store
506506
using GmemTiledCopydQ = decltype(
507507
make_tiled_copy(
508508
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
509509
GmemLayoutAtomQKVO{},
510-
Layout<Shape < _1, _8>>{}
510+
Layout<Shape<_1, _8>>{}
511511
)
512512
); // Val layout, 8 vals per store
513513
using GmemLayoutAtomdQaccum = std::conditional_t<
514514
kBlockKSmem == 32,
515-
Layout<Shape <_32, _8>, Stride< _8, _1>>, // Thread layout, 8 threads per row
516-
Layout<Shape <_16, _16>, Stride< _16, _1>> // Thread layout, 16 threads per row
515+
Layout<Shape<_32, _8>, Stride<_8, _1>>, // Thread layout, 8 threads per row
516+
Layout<Shape<_16, _16>, Stride<_16, _1>> // Thread layout, 16 threads per row
517517
>;
518518
using GmemTiledCopydQaccum = decltype(
519519
make_tiled_copy(
520520
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
521521
GmemLayoutAtomdQaccum{},
522-
Layout<Shape < _1, _4>>{}
522+
Layout<Shape<_1, _4>>{}
523523
)
524524
); // Val layout, 4 vals per store
525525

526526
using GmemTiledCopydQaccumAtomicAdd = decltype(
527527
make_tiled_copy(
528528
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
529-
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
529+
Layout<Shape<_8, _32>, // Thread layout, 32 threads per row
530530
Stride<_32, _1>>{},
531-
Layout<Shape < _1, _1>>{}
531+
Layout<Shape<_1, _1>>{}
532532
)
533533
); // Val layout, 1 val per store
534534
};

0 commit comments

Comments
 (0)