@@ -229,15 +229,15 @@ struct Flash_fwd_kernel_traits : public Base {
229229 // Accumulator layout for output
// Compile-time choice between two thread layouts of 128 entries each:
// 16 rows x 8 threads when kBlockKSmem == 32, otherwise 8 rows x 16 threads.
// (The -/+ line pairs in this hunk differ only in whitespace.)
230230 using GmemLayoutAtomOaccum = std::conditional_t <
231231 kBlockKSmem == 32 ,
232- Layout<Shape <_16, _8>, Stride< _8, _1>>, // Thread layout, 8 threads per row
233- Layout<Shape <_8, _16>, Stride< _16, _1>> // Thread layout, 16 threads per row
232+ Layout<Shape<_16, _8>, Stride<_8, _1>>, // Thread layout, 8 threads per row
233+ Layout<Shape<_8, _16>, Stride<_16, _1>> // Thread layout, 16 threads per row
234234 >;
235235
// Tiled copy type for ElementAccum values, built from three pieces: the copy atom,
// the thread layout GmemLayoutAtomOaccum, and a 1 x 4 value layout.
// NOTE(review): the <128> argument is presumably an assumed alignment in bits — confirm against the CuTe headers.
// (The -/+ pair in this hunk differs only in whitespace.)
236236 using GmemTiledCopyOaccum = decltype (
237237 make_tiled_copy (
238238 Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128 >, ElementAccum>{},
239239 GmemLayoutAtomOaccum{},
240- Layout<Shape < _1, _4>>{}
240+ Layout<Shape< _1, _4>>{}
241241 )
242242 ); // Val layout, 4 vals per store
243243};
@@ -442,15 +442,15 @@ struct Flash_bwd_kernel_traits : public Base {
// Compile-time checks: kNThreads must divide evenly by the mask and bias threads-per-row
// counts, so the quotients kNThreads / kGmemThreadsPerRow{Mask,Bias} used below are exact.
442442 static_assert (kNThreads % kGmemThreadsPerRowMask == 0 , " kNThreads must be a multiple of kGmemThreadsPerRowMask" );
443443 static_assert (kNThreads % kGmemThreadsPerRowBias == 0 , " kNThreads must be a multiple of kGmemThreadsPerRowBias" );
// Thread layout with (kNThreads / kGmemThreadsPerRowQKVO) rows of kGmemThreadsPerRowQKVO threads.
// Stride (kGmemThreadsPerRowQKVO, 1): index = row * kGmemThreadsPerRowQKVO + column.
// (The -/+ pair in this hunk differs only in whitespace.)
444444 using GmemLayoutAtomQKVO = Layout<
445- Shape <Int<kNThreads / kGmemThreadsPerRowQKVO >, Int<kGmemThreadsPerRowQKVO >>,
445+ Shape<Int<kNThreads / kGmemThreadsPerRowQKVO >, Int<kGmemThreadsPerRowQKVO >>,
446446 Stride<Int<kGmemThreadsPerRowQKVO >, _1>
447447 >;
// Thread layout with (kNThreads / kGmemThreadsPerRowMask) rows of kGmemThreadsPerRowMask threads.
// Stride (kGmemThreadsPerRowMask, 1): index = row * kGmemThreadsPerRowMask + column.
// (The -/+ pair in this hunk differs only in whitespace.)
448448 using GmemLayoutAtomMask = Layout<
449- Shape <Int<kNThreads / kGmemThreadsPerRowMask >, Int<kGmemThreadsPerRowMask >>,
449+ Shape<Int<kNThreads / kGmemThreadsPerRowMask >, Int<kGmemThreadsPerRowMask >>,
450450 Stride<Int<kGmemThreadsPerRowMask >, _1>
451451 >;
// Thread layout with (kNThreads / kGmemThreadsPerRowBias) rows of kGmemThreadsPerRowBias threads.
// Stride (kGmemThreadsPerRowBias, 1): index = row * kGmemThreadsPerRowBias + column.
// (The -/+ pair in this hunk differs only in whitespace.)
452452 using GmemLayoutAtomBias = Layout<
453- Shape <Int<kNThreads / kGmemThreadsPerRowBias >, Int<kGmemThreadsPerRowBias >>,
453+ Shape<Int<kNThreads / kGmemThreadsPerRowBias >, Int<kGmemThreadsPerRowBias >>,
454454 Stride<Int<kGmemThreadsPerRowBias >, _1>
455455 >;
456456
@@ -486,7 +486,7 @@ struct Flash_bwd_kernel_traits : public Base {
486486 make_tiled_copy (
487487 Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128 >, elem_type>{},
488488 GmemLayoutAtomQKVO{},
489- Layout<Shape < _1, _8>>{}
489+ Layout<Shape< _1, _8>>{}
490490 )
491491 ); // Val layout, 8 vals per store
492492 using GmemTiledCopydBias = decltype (
@@ -500,35 +500,35 @@ struct Flash_bwd_kernel_traits : public Base {
500500 make_tiled_copy (
501501 Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128 >, elem_type>{},
502502 GmemLayoutAtomQKVO{},
503- Layout<Shape < _1, _8>>{}
503+ Layout<Shape< _1, _8>>{}
504504 )
505505 ); // Val layout, 8 vals per store
// Tiled copy type for elem_type values: the copy atom combined with the thread layout
// GmemLayoutAtomQKVO and a 1 x 8 value layout.
// NOTE(review): presumably used for the dQ gradient, going by the name — confirm at the use site.
// (The -/+ pair in this hunk differs only in whitespace.)
506506 using GmemTiledCopydQ = decltype (
507507 make_tiled_copy (
508508 Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128 >, elem_type>{},
509509 GmemLayoutAtomQKVO{},
510- Layout<Shape < _1, _8>>{}
510+ Layout<Shape< _1, _8>>{}
511511 )
512512 ); // Val layout, 8 vals per store
// Compile-time choice between two thread layouts of 256 entries each:
// 32 rows x 8 threads when kBlockKSmem == 32, otherwise 16 rows x 16 threads.
// (The -/+ line pairs in this hunk differ only in whitespace.)
513513 using GmemLayoutAtomdQaccum = std::conditional_t <
514514 kBlockKSmem == 32 ,
515- Layout<Shape <_32, _8>, Stride< _8, _1>>, // Thread layout, 8 threads per row
516- Layout<Shape <_16, _16>, Stride< _16, _1>> // Thread layout, 16 threads per row
515+ Layout<Shape<_32, _8>, Stride<_8, _1>>, // Thread layout, 8 threads per row
516+ Layout<Shape<_16, _16>, Stride<_16, _1>> // Thread layout, 16 threads per row
517517 >;
// Tiled copy type for ElementAccum values: the copy atom combined with the thread layout
// GmemLayoutAtomdQaccum and a 1 x 4 value layout.
// (The -/+ pair in this hunk differs only in whitespace.)
518518 using GmemTiledCopydQaccum = decltype (
519519 make_tiled_copy (
520520 Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128 >, ElementAccum>{},
521521 GmemLayoutAtomdQaccum{},
522- Layout<Shape < _1, _4>>{}
522+ Layout<Shape< _1, _4>>{}
523523 )
524524 ); // Val layout, 4 vals per store
525525
// Tiled copy type for ElementAccum values with an inline 8 x 32 thread layout (256 entries)
// and a 1 x 1 value layout, i.e. a single value per store.
// NOTE(review): the name suggests the destination is updated with atomic adds, which would
// explain the one-value granularity; the atomic operation itself is not visible here — confirm at the use site.
// (The -/+ line pairs in this hunk differ only in whitespace.)
526526 using GmemTiledCopydQaccumAtomicAdd = decltype (
527527 make_tiled_copy (
528528 Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128 >, ElementAccum>{},
529- Layout<Shape <_8, _32>, // Thread layout, 8 rows of 32 threads (32 threads per row)
529+ Layout<Shape<_8, _32>, // Thread layout, 8 rows of 32 threads (32 threads per row)
530530 Stride<_32, _1>>{},
531- Layout<Shape < _1, _1>>{}
531+ Layout<Shape< _1, _1>>{}
532532 )
533533 ); // Val layout, 1 val per store
534534};
0 commit comments