@@ -446,6 +446,57 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
446446 : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
447447 }
448448
// Build the list of 2-element register bases checked for a sub-group
// transpose. For power-of-two sizes the result looks like:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [laneSize, 0], [2 * laneSize, 0], ..., [registerSize / 2, 0]]
//
// `registerSize` bounds the second run of bases; `laneSize` bounds the first
// run and is the starting value of the second.
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Powers of two strictly below laneSize populate the second component.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    result.push_back({0, stride});
  // From laneSize up to (but excluding) registerSize, the doubled value moves
  // to the first component and the second one stays zero.
  for (int32_t stride = laneSize; stride < registerSize; stride *= 2)
    result.push_back({stride, 0});
  return result;
}
467+
// Build the list of 2-element register bases checked for a sub-group shuffle.
// For power-of-two sizes the result looks like:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [1, 0], [2, 0], ..., [registerSize / (2 * laneSize), 0]]
//
// `laneSize` bounds the first run of bases; the second run has one entry per
// doubling from `laneSize` up to (but excluding) `registerSize`.
static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Powers of two strictly below laneSize populate the second component.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    result.push_back({0, stride});
  // The first component restarts at 1 and doubles alongside the bound.
  int32_t value = 1;
  for (int32_t bound = laneSize; bound < registerSize; bound *= 2) {
    result.push_back({value, 0});
    value *= 2;
  }
  return result;
}
486+
// Build the list of 2-element lane bases checked for a sub-group transpose.
// For a power-of-two `laneSize` the result looks like:
//   [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]]
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // One basis per power of two strictly below laneSize, stored in the first
  // component; the second component is always zero.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    result.push_back({stride, 0});
  return result;
}
499+
449500 bool isSubGroupTranspose (const LinearLayout &srcLayout,
450501 const LinearLayout &dstLayout) const {
451502 MLIRContext *ctx = srcLayout.getInDimNames ().begin ()->getContext ();
@@ -476,35 +527,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
476527 // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
477528 //
478529 // With N >= M.
479- const auto buildBasis = [&](int32_t size, std::size_t index) {
480- std::vector<std::vector<int32_t >> basis;
481- std::vector<int32_t > curr (2 );
482- for (int32_t i = 1 ; i < size; i *= 2 ) {
483- curr[index] = i;
484- basis.push_back (curr);
485- }
486- return basis;
487- };
488- constexpr std::size_t laneIndex = 0 ;
489- constexpr std::size_t registerIndex = 1 ;
490- int32_t laneSize = conversion->getInDimSize (kLane );
491- std::vector<std::vector<int32_t >> registerBases =
492- buildBasis (laneSize, registerIndex);
493- {
494- // Populate register bases for N > M.
495- std::vector<int32_t > base (2 );
496- for (int32_t i = laneSize,
497- registerSize = conversion->getInDimSize (kRegister );
498- i < registerSize; i *= 2 ) {
499- base[laneIndex] = i;
500- registerBases.push_back (base);
501- }
502- }
503- std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
504- bases{{{kRegister , std::move (registerBases)},
505- {kLane , buildBasis (laneSize, laneIndex)}}};
506- std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
507- return conversion == LinearLayout (bases, outDimNames);
530+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
531+ int32_t laneInDimSize = conversion->getInDimSize (kLane );
532+ return conversion->getBases ().lookup (kRegister ) ==
533+ buildSubGroupTransposeRegisterBases (registerInDimSize,
534+ laneInDimSize) &&
535+ conversion->getBases ().lookup (kLane ) ==
536+ buildSubGroupTransposeLaneBases (laneInDimSize);
508537 }
509538
510539 LogicalResult
@@ -619,32 +648,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
619648 // Expected conversion is:
620649 // - register=1 -> (0, 1)
621650 // ...
622- // register=i -> (0, i)
651+ // - register=2** i -> (0, 2** i)
623652 // ...
624- // register=N -> (0, N)
625- // - lane=1 -> (0, 0)
653+ // - register=2**M -> (0, 2**M)
626654 // ...
627- // lane=i -> (0 , 0)
655+ // - register=2**k -> (2**(k-M-1), 0)
628656 // ...
629- // lane=N -> (0, 0)
630- // where out dims are: [register (size 1), lane (size N)]
631- std::vector<std::vector<int32_t >> registerBases;
632- {
633- constexpr std::size_t registerIndex = 1 ;
634- std::vector<int32_t > base (2 );
635- for (int32_t i = 1 , n = conversion->getInDimSize (kLane ); i < n; i *= 2 ) {
636- base[registerIndex] = i;
637- registerBases.push_back (base);
638- }
639- }
640-
641- std::vector<std::vector<int32_t >> laneBases (
642- conversion->getInDimSizeLog2 (kLane ), std::vector<int32_t >{0 , 0 });
643- std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
644- bases{{{kRegister , std::move (registerBases)},
645- {kLane , std::move (laneBases)}}};
646- std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
647- return conversion == LinearLayout (bases, outDimNames);
657+ // - register=2**N -> (2**(N-M-1), 0)
658+ // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
659+ //
660+ // With N >= M.
661+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
662+ int32_t laneOutDimSize = conversion->getOutDimSize (kLane );
663+ return conversion->getBases ().lookup (kRegister ) ==
664+ buildSubGroupShuffleRegisterBases (registerInDimSize, laneOutDimSize);
648665 }
649666
650667 bool isSupportedSubGroupShuffle (ConvertLayoutOp, OpAdaptor) const {
@@ -674,7 +691,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
674691
675692 SmallVector<Value> inVals =
676693 unpackLLElements (loc, adaptor.getSrc (), rewriter);
677- assert (inVals.size () == 1 && " Expecting single element" );
678694
679695 // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
680696 // upstream level. We are not enabling support for all types here as that
@@ -703,7 +719,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
703719 });
704720
705721 SmallVector<Value> outVals =
706- performSubGroupShuffle (loc, inVals. front () , subGroupSize, rewriter);
722+ performSubGroupShuffle (loc, inVals, subGroupSize, rewriter);
707723
708724 // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
709725 // upstream level. We are not enabling support for all types here as that
@@ -734,16 +750,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
734750 }
735751
736752 SmallVector<Value>
737- performSubGroupShuffle (Location loc, Value val, int32_t subGroupSize,
753+ performSubGroupShuffle (Location loc, ArrayRef<Value> inVals,
754+ int32_t subGroupSize,
738755 ConversionPatternRewriter &rewriter) const {
739756 SmallVector<Value> res;
740757 Value width = i32_val (subGroupSize);
741- for (int32_t i = 0 ; i < subGroupSize; ++i)
742- res.push_back (
743- rewriter
744- .create <mlir::gpu::ShuffleOp>(loc, val, i32_val (i), width,
745- mlir::gpu::ShuffleMode::IDX)
746- .getShuffleResult ());
758+ for (Value val : inVals) {
759+ for (int32_t i = 0 ; i < subGroupSize; ++i)
760+ res.push_back (
761+ rewriter
762+ .create <mlir::gpu::ShuffleOp>(loc, val, i32_val (i), width,
763+ mlir::gpu::ShuffleMode::IDX)
764+ .getShuffleResult ());
765+ }
747766 return res;
748767 }
749768
0 commit comments