@@ -446,6 +446,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
446446 : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
447447 }
448448
449+ // Return a vector such as:
450+ // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
451+ // [registerSize / 2, 0]],
452+ // i.e., mapping registers to lanes till laneSize and performing an ID
453+ // conversion afterwards.
454+ static std::vector<std::vector<int32_t >>
455+ buildSubGroupTransposeRegisterBases (int32_t registerSize, int32_t laneSize) {
456+ std::vector<std::vector<int32_t >> bases;
457+ std::vector<int32_t > curr (2 );
458+ for (int32_t i = 1 ; i < laneSize; i *= 2 ) {
459+ curr[1 ] = i;
460+ bases.push_back (curr);
461+ }
462+ curr[1 ] = 0 ;
463+ for (int32_t i = laneSize; i < registerSize; i *= 2 ) {
464+ curr[0 ] = i;
465+ bases.push_back (curr);
466+ }
467+ return bases;
468+ }
469+
470+ // Return a vector such as:
471+ // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
472+ // [registerSize / (2 * laneSize), 0]]
473+ // i.e., mapping registers to lanes till laneSize and repeating the pattern
474+ // afterwards.
475+ static std::vector<std::vector<int32_t >>
476+ buildSubGroupShuffleRegisterBases (int32_t registerSize, int32_t laneSize) {
477+ std::vector<std::vector<int32_t >> bases;
478+ std::vector<int32_t > curr (2 );
479+ for (int32_t i = 1 ; i < laneSize; i *= 2 ) {
480+ curr[1 ] = i;
481+ bases.push_back (curr);
482+ }
483+ curr[1 ] = 0 ;
484+ for (int32_t i = laneSize, val = 1 ; i < registerSize; i *= 2 , val *= 2 ) {
485+ curr[0 ] = val;
486+ bases.push_back (curr);
487+ }
488+ return bases;
489+ }
490+
491+ // Return a vector such as:
492+ // [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
493+ // i.e., mapping lanes to registers.
494+ static std::vector<std::vector<int32_t >>
495+ buildSubGroupTransposeLaneBases (int32_t laneSize) {
496+ std::vector<std::vector<int32_t >> bases;
497+ std::vector<int32_t > curr (2 );
498+ for (int32_t i = 1 ; i < laneSize; i *= 2 ) {
499+ curr[0 ] = i;
500+ bases.push_back (curr);
501+ }
502+ return bases;
503+ }
504+
449505 bool isSubGroupTranspose (const LinearLayout &srcLayout,
450506 const LinearLayout &dstLayout) const {
451507 MLIRContext *ctx = srcLayout.getInDimNames ().begin ()->getContext ();
@@ -476,35 +532,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
476532 // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
477533 //
478534 // With N >= M.
479- const auto buildBasis = [&](int32_t size, std::size_t index) {
480- std::vector<std::vector<int32_t >> basis;
481- std::vector<int32_t > curr (2 );
482- for (int32_t i = 1 ; i < size; i *= 2 ) {
483- curr[index] = i;
484- basis.push_back (curr);
485- }
486- return basis;
487- };
488- constexpr std::size_t laneIndex = 0 ;
489- constexpr std::size_t registerIndex = 1 ;
490- int32_t laneSize = conversion->getInDimSize (kLane );
491- std::vector<std::vector<int32_t >> registerBases =
492- buildBasis (laneSize, registerIndex);
493- {
494- // Populate register bases for N > M.
495- std::vector<int32_t > base (2 );
496- for (int32_t i = laneSize,
497- registerSize = conversion->getInDimSize (kRegister );
498- i < registerSize; i *= 2 ) {
499- base[laneIndex] = i;
500- registerBases.push_back (base);
501- }
502- }
503- std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
504- bases{{{kRegister , std::move (registerBases)},
505- {kLane , buildBasis (laneSize, laneIndex)}}};
506- std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
507- return conversion == LinearLayout (bases, outDimNames);
535+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
536+ int32_t laneInDimSize = conversion->getInDimSize (kLane );
537+ return conversion->getBases ().lookup (kRegister ) ==
538+ buildSubGroupTransposeRegisterBases (registerInDimSize,
539+ laneInDimSize) &&
540+ conversion->getBases ().lookup (kLane ) ==
541+ buildSubGroupTransposeLaneBases (laneInDimSize);
508542 }
509543
510544 LogicalResult
@@ -619,32 +653,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
619653 // Expected conversion is:
620654 // - register=1 -> (0, 1)
621655 // ...
622- // register=i -> (0, i)
656+ // - register=2**i -> (0, 2**i)
657+ // ...
658+ // - register=M -> (0, 2**M)
659+ // ...
660+ // - register=2**k -> (2**(k-M), 0)
623661 // ...
624- // register=N -> (0, N )
662+ // - register=2** N -> (2**(N-M), 0 )
625663 // - lane=1 -> (0, 0)
626664 // ...
627- // lane=i -> (0, 0)
665+ // - lane=2**j -> (0, 0)
628666 // ...
629- // lane=N -> (0, 0)
630- // where out dims are: [register (size 1), lane (size N)]
631- std::vector<std::vector<int32_t >> registerBases;
632- {
633- constexpr std::size_t registerIndex = 1 ;
634- std::vector<int32_t > base (2 );
635- for (int32_t i = 1 , n = conversion->getInDimSize (kLane ); i < n; i *= 2 ) {
636- base[registerIndex] = i;
637- registerBases.push_back (base);
638- }
639- }
640-
641- std::vector<std::vector<int32_t >> laneBases (
642- conversion->getInDimSizeLog2 (kLane ), std::vector<int32_t >{0 , 0 });
643- std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
644- bases{{{kRegister , std::move (registerBases)},
645- {kLane , std::move (laneBases)}}};
646- std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
647- return conversion == LinearLayout (bases, outDimNames);
667+ // lane=2**M -> (0, 0)
668+ // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
669+ //
670+ // With N >= M.
671+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
672+ int32_t laneOutDimSize = conversion->getOutDimSize (kLane );
673+ return conversion->sublayoutIsZero ({kLane }, {kRegister , kLane }) &&
674+ conversion->getBases ().lookup (kRegister ) ==
675+ buildSubGroupShuffleRegisterBases (registerInDimSize,
676+ laneOutDimSize);
648677 }
649678
650679 bool isSupportedSubGroupShuffle (ConvertLayoutOp, OpAdaptor) const {
@@ -674,7 +703,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
674703
675704 SmallVector<Value> inVals =
676705 unpackLLElements (loc, adaptor.getSrc (), rewriter);
677- assert (inVals.size () == 1 && " Expecting single element" );
678706
679707 // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
680708 // upstream level. We are not enabling support for all types here as that
@@ -703,7 +731,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
703731 });
704732
705733 SmallVector<Value> outVals =
706- performSubGroupShuffle (loc, inVals. front () , subGroupSize, rewriter);
734+ performSubGroupShuffle (loc, inVals, subGroupSize, rewriter);
707735
708736 // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
709737 // upstream level. We are not enabling support for all types here as that
@@ -734,16 +762,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
734762 }
735763
736764 SmallVector<Value>
737- performSubGroupShuffle (Location loc, Value val, int32_t subGroupSize,
765+ performSubGroupShuffle (Location loc, ArrayRef<Value> inVals,
766+ int32_t subGroupSize,
738767 ConversionPatternRewriter &rewriter) const {
739768 SmallVector<Value> res;
740769 Value width = i32_val (subGroupSize);
741- for (int32_t i = 0 ; i < subGroupSize; ++i)
742- res.push_back (
743- rewriter
744- .create <mlir::gpu::ShuffleOp>(loc, val, i32_val (i), width,
745- mlir::gpu::ShuffleMode::IDX)
746- .getShuffleResult ());
770+ for (Value val : inVals) {
771+ for (int32_t i = 0 ; i < subGroupSize; ++i)
772+ res.push_back (
773+ rewriter
774+ .create <mlir::gpu::ShuffleOp>(loc, val, i32_val (i), width,
775+ mlir::gpu::ShuffleMode::IDX)
776+ .getShuffleResult ());
777+ }
747778 return res;
748779 }
749780
0 commit comments