@@ -446,6 +446,57 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
446446 : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
447447 }
448448
// Build the list of two-element register bases expected for a sub-group
// transpose conversion (compared against the kRegister bases of the
// conversion layout).
//
// Returns a vector such as:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [laneSize, 0], [2 * laneSize, 0], ..., [registerSize / 2, 0]]
//
// The first group holds one entry per power of two below `laneSize` in the
// second component; the second group holds one entry per power-of-two multiple
// of `laneSize` below `registerSize` in the first component.
//
// Both sizes are expected to be positive powers of two. A non-positive
// `laneSize` is treated like a lane size of 1 so that the second loop always
// makes progress (doubling 0 would never terminate).
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // Counters are 64-bit: doubling a 32-bit counter above 2^30 would overflow,
  // which is undefined behaviour for signed integers.
  for (int64_t i = 1; i < laneSize; i *= 2)
    bases.push_back({0, static_cast<int32_t>(i)});
  const int64_t firstRegisterBasis = laneSize > 0 ? laneSize : 1;
  for (int64_t i = firstRegisterBasis; i < registerSize; i *= 2)
    bases.push_back({static_cast<int32_t>(i), 0});
  return bases;
}
467+
// Build the list of two-element register bases expected for a sub-group
// shuffle conversion (compared against the kRegister bases of the conversion
// layout).
//
// Returns a vector such as:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [1, 0], [2, 0], ..., [registerSize / (2 * laneSize), 0]]
//
// The first group holds one entry per power of two below `laneSize` in the
// second component; the second group holds one entry per doubling of
// `laneSize` below `registerSize`, with the first component counting up from 1
// in powers of two.
//
// Both sizes are expected to be positive powers of two. A non-positive
// `laneSize` is treated like a lane size of 1 so that the second loop always
// makes progress (doubling 0 would never terminate).
static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // Counters are 64-bit: doubling a 32-bit counter above 2^30 would overflow,
  // which is undefined behaviour for signed integers.
  for (int64_t i = 1; i < laneSize; i *= 2)
    bases.push_back({0, static_cast<int32_t>(i)});
  const int64_t firstRegisterIndex = laneSize > 0 ? laneSize : 1;
  // `val` never exceeds `i`, so it always fits in 32 bits when pushed.
  for (int64_t i = firstRegisterIndex, val = 1; i < registerSize;
       i *= 2, val *= 2)
    bases.push_back({static_cast<int32_t>(val), 0});
  return bases;
}
486+
// Build the list of two-element lane bases expected for a sub-group transpose
// conversion (compared against the kLane bases of the conversion layout).
//
// Returns a vector such as:
//   [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]]
//
// i.e. one entry per power of two strictly below `laneSize`, stored in the
// first component. The result is empty when `laneSize <= 1`.
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // The counter is 64-bit: doubling a 32-bit counter above 2^30 would
  // overflow, which is undefined behaviour for signed integers.
  for (int64_t i = 1; i < laneSize; i *= 2)
    bases.push_back({static_cast<int32_t>(i), 0});
  return bases;
}
499+
449500 bool isSubGroupTranspose (const LinearLayout &srcLayout,
450501 const LinearLayout &dstLayout) const {
451502 MLIRContext *ctx = srcLayout.getInDimNames ().begin ()->getContext ();
@@ -477,35 +528,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
477528 // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
478529 //
479530 // With N >= M.
480- const auto buildBasis = [&](int32_t size, std::size_t index) {
481- std::vector<std::vector<int32_t >> basis;
482- std::vector<int32_t > curr (2 );
483- for (int32_t i = 1 ; i < size; i *= 2 ) {
484- curr[index] = i;
485- basis.push_back (curr);
486- }
487- return basis;
488- };
489- constexpr std::size_t laneIndex = 0 ;
490- constexpr std::size_t registerIndex = 1 ;
491- int32_t laneSize = conversion->getInDimSize (kLane );
492- std::vector<std::vector<int32_t >> registerBases =
493- buildBasis (laneSize, registerIndex);
494- {
495- // Populate register bases for N > M.
496- std::vector<int32_t > base (2 );
497- for (int32_t i = laneSize,
498- registerSize = conversion->getInDimSize (kRegister );
499- i < registerSize; i *= 2 ) {
500- base[laneIndex] = i;
501- registerBases.push_back (base);
502- }
503- }
504- std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
505- bases{{{kRegister , std::move (registerBases)},
506- {kLane , buildBasis (laneSize, laneIndex)}}};
507- std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
508- return conversion == LinearLayout (bases, outDimNames);
531+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
532+ int32_t laneInDimSize = conversion->getInDimSize (kLane );
533+ return conversion->getBases ().lookup (kRegister ) ==
534+ buildSubGroupTransposeRegisterBases (registerInDimSize,
535+ laneInDimSize) &&
536+ conversion->getBases ().lookup (kLane ) ==
537+ buildSubGroupTransposeLaneBases (laneInDimSize);
509538 }
510539
511540 LogicalResult
@@ -612,39 +641,29 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
612641
613642 LinearLayout comp = dstLayout.invertAndCompose (srcLayout);
614643 std::optional<LinearLayout> conversion = comp.divideRight (
644+ LinearLayout::zeros1D (comp.getInDimSize (kLane ), kLane , kLane ) *
615645 LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
616646 LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
617- assert (conversion && " Expecting valid conversion" );
647+ if (!conversion)
648+ return false ;
618649 // TODO: Support more kind of shuffles.
619650 // Expected conversion is:
620651 // - register=1 -> (0, 1)
621652 // ...
622- // register=i -> (0, i)
653+ // - register=2** i -> (0, 2** i)
623654 // ...
624- // register=N -> (0, N)
625- // - lane=1 -> (0, 0)
655+ // - register=2**M -> (0, 2**M)
626656 // ...
627- // lane=i -> (0 , 0)
657+ // - register=2**k -> (2**(k-M-1), 0)
628658 // ...
629- // lane=N -> (0, 0)
630- // where out dims are: [register (size 1), lane (size N)]
631- std::vector<std::vector<int32_t >> registerBases;
632- {
633- constexpr std::size_t registerIndex = 1 ;
634- std::vector<int32_t > base (2 );
635- for (int32_t i = 1 , n = conversion->getInDimSize (kLane ); i < n; i *= 2 ) {
636- base[registerIndex] = i;
637- registerBases.push_back (base);
638- }
639- }
640-
641- std::vector<std::vector<int32_t >> laneBases (
642- conversion->getInDimSizeLog2 (kLane ), std::vector<int32_t >{0 , 0 });
643- std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
644- bases{{{kRegister , std::move (registerBases)},
645- {kLane , std::move (laneBases)}}};
646- std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
647- return conversion == LinearLayout (bases, outDimNames);
659+ // - register=2**N -> (2**(N-M-1), 0)
660+ // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
661+ //
662+ // With N >= M.
663+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
664+ int32_t laneOutDimSize = conversion->getOutDimSize (kLane );
665+ return conversion->getBases ().lookup (kRegister ) ==
666+ buildSubGroupShuffleRegisterBases (registerInDimSize, laneOutDimSize);
648667 }
649668
650669 bool isSupportedSubGroupShuffle (ConvertLayoutOp, OpAdaptor) const {
@@ -677,7 +696,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
677696
678697 SmallVector<Value> inVals =
679698 unpackLLElements (loc, adaptor.getSrc (), rewriter);
680- assert (inVals.size () == 1 && " Expecting single element" );
681699
682700 // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
683701 // upstream level. We are not enabling support for all types here as that
@@ -706,7 +724,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
706724 });
707725
708726 SmallVector<Value> outVals =
709- performSubGroupShuffle (loc, inVals. front () , subGroupSize, rewriter);
727+ performSubGroupShuffle (loc, inVals, subGroupSize, rewriter);
710728
711729 // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
712730 // upstream level. We are not enabling support for all types here as that
@@ -737,16 +755,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
737755 }
738756
// Emit the sub-group shuffles for a layout conversion.
//
// In the updated version (lines marked '+'), for every value in `inVals` this
// creates `subGroupSize` mlir::gpu::ShuffleOp operations in IDX mode, one per
// index i in [0, subGroupSize), each using `subGroupSize` as the shuffle
// width. The shuffle results are appended to `res` grouped by input value, in
// input order, and `res` is returned. Lines marked '-' are the previous
// version, which shuffled a single value `val`.
739757 SmallVector<Value>
740- performSubGroupShuffle (Location loc, Value val, int32_t subGroupSize,
758+ performSubGroupShuffle (Location loc, ArrayRef<Value> inVals,
759+ int32_t subGroupSize,
741760 ConversionPatternRewriter &rewriter) const {
742761 SmallVector<Value> res;
743762 Value width = i32_val (subGroupSize);
744- for (int32_t i = 0 ; i < subGroupSize; ++i)
745- res.push_back (
746- rewriter
747- .create <mlir::gpu::ShuffleOp>(loc, val, i32_val (i), width,
748- mlir::gpu::ShuffleMode::IDX)
749- .getShuffleResult ());
763+ for (Value val : inVals) {
764+ for (int32_t i = 0 ; i < subGroupSize; ++i)
765+ res.push_back (
766+ rewriter
767+ .create <mlir::gpu::ShuffleOp>(loc, val, i32_val (i), width,
768+ mlir::gpu::ShuffleMode::IDX)
769+ .getShuffleResult ());
770+ }
750771 return res;
751772 }
752773
0 commit comments