@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
546546// / and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
547547// / [0] = {0, [0, 10)}, {1, [0, 5)}
548548// / [1] = {1, [5, 10)}, {2, [0, 10)}
549+ // / and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
550+ // / [0] = {0, [0, 4)}, {1, [0, 4)}
551+ // / [1] = {2, [0, 4)}, {3, [0, 4)}
552+ // / and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
553+ // / [0] = {0, [0, 4)}
554+ // / [1] = {0, [4, 8)}
555+ // / [2] = {1, [0, 4)}
556+ // / [3] = {1, [4, 8)}
549557struct BitCastBitsEnumerator {
550558 BitCastBitsEnumerator (VectorType sourceVectorType,
551559 VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
633641// / `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
634642// / the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
635643// / bits extracted from the source vector (i.e. the `shuffle -> and` part).
644+ // /
645+ // /
646+ // / When we consider the above algorithm to rewrite our vector.bitcast, we rely
647+ // / on using dynamic shift amounts for the left and right shifts. This can be
648+ // / inefficient on certain targets (RDNA GPUs) in contrast to a splat constant
649+ // / value. So when possible we can rewrite this as a combination of shifts with
650+ // / a constant splat value and then regroup the selected terms.
651+ // /
652+ // / Eg. Instead of:
653+ // / res = arith.shrui x [0, 4, 8, 0, 4, 8]
654+ // / use:
655+ // / y = arith.shrui x [0, 0, 0, 0, 0, 0] (can be folded away)
656+ // / y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
657+ // / y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
658+ // / y3 = vector.shuffle y y1 [0, 7, 3, 10]
659+ // / res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
660+ // /
661+ // / This is possible when the precomputed shift amounts follow a cyclic
662+ // / pattern of [x, y, z, ..., x, y, z, ...] such that the cycle length,
663+ // / cycleLen, satisfies 1 < cycleLen < size(shiftAmounts). And the shuffles are
664+ // / of the form [0, 0, 0, ..., 1, 1, 1, ...]. A common pattern in
665+ // / (de)quantization, i24 -> 3xi8 or 3xi8 -> i24. The modified algorithm follows
666+ // / the same 2 steps as above, then it proceeds as follows:
667+ // /
668+ // / 2. for each element in the cycle, x, of the rightShiftAmounts create a
669+ // / shrui with a splat constant of x.
670+ // / 3. repeat 2. with the respective leftShiftAmounts
671+ // / 4. construct a chain of vector.shuffles that will reconstruct the result
672+ // / from the chained shifts
636673struct BitCastRewriter {
637674 // / Helper metadata struct to hold the static quantities for the rewrite.
638675 struct Metadata {
@@ -656,10 +693,25 @@ struct BitCastRewriter {
656693 Value initialValue, Value runningResult,
657694 const BitCastRewriter::Metadata &metadata);
658695
696+ // / Rewrite one step of the sequence when able to use a splat constant for the
697+ // / shiftright and shiftleft.
698+ Value splatRewriteStep (PatternRewriter &rewriter, Location loc,
699+ Value initialValue, Value runningResult,
700+ const BitCastRewriter::Metadata &metadata);
701+
702+ bool useSplatStep (unsigned maxCycleLen) {
703+ return 1 < cycleLen && cycleLen <= maxCycleLen;
704+ }
705+
659706private:
660707 // / Underlying enumerator that encodes the provenance of the bits in the each
661708 // / element of the result vector.
662709 BitCastBitsEnumerator enumerator;
710+
711+ // Underlying cycleLen computed during precomputeMetadata. A cycleLen > 1
712+ // denotes that there is a cycle in the precomputed shift amounts and we are
713+ // able to use the splatRewriteStep.
714+ int64_t cycleLen = 0 ;
663715};
664716
665717} // namespace
@@ -775,8 +827,40 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
775827 return success ();
776828}
777829
830+ // Check if the vector is a cycle of the first cycleLen elements.
831+ template <class T >
832+ static bool isCyclic (SmallVector<T> xs, int64_t cycleLen) {
833+ for (int64_t idx = cycleLen, n = xs.size (); idx < n; idx++) {
834+ if (xs[idx] != xs[idx % cycleLen])
835+ return false ;
836+ }
837+ return true ;
838+ }
839+
840+ static SmallVector<int64_t > constructShuffles (int64_t inputSize,
841+ int64_t numCycles,
842+ int64_t cycleLen, int64_t idx) {
843+ // If idx == 1, then the first operand of the shuffle will be the mask which
844+ // will have the original size. So we need to step through the mask with a
845+ // stride of cycleSize.
846+ // When idx > 1, then the first operand will be the size of (idx * cycleSize)
847+ // and so we take the first idx elements of the input and then append the
848+ // strided mask value.
849+ int64_t inputStride = idx == 1 ? cycleLen : idx;
850+
851+ SmallVector<int64_t > shuffles;
852+ for (int64_t cycle = 0 ; cycle < numCycles; cycle++) {
853+ for (int64_t inputIdx = 0 ; inputIdx < idx; inputIdx++) {
854+ shuffles.push_back (cycle * inputStride + inputIdx);
855+ }
856+ shuffles.push_back (inputSize + cycle * cycleLen + idx);
857+ }
858+ return shuffles;
859+ }
860+
778861SmallVector<BitCastRewriter::Metadata>
779862BitCastRewriter::precomputeMetadata (IntegerType shuffledElementType) {
863+ bool cyclicShifts = true ;
780864 SmallVector<BitCastRewriter::Metadata> result;
781865 for (int64_t shuffleIdx = 0 , e = enumerator.getMaxNumberOfEntries ();
782866 shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +895,71 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
811895 IntegerAttr::get (shuffledElementType, shiftLeft));
812896 }
813897
898+ // Compute a potential cycle size by detecting the number of sourceElements
899+ // at the start of shuffle that are the same
900+ cycleLen = 1 ;
901+ for (int64_t n = shuffles.size (); cycleLen < n; cycleLen++)
902+ if (shuffles[cycleLen] != shuffles[0 ])
903+ break ;
904+
905+ cyclicShifts = cyclicShifts && (cycleLen < (int64_t )shuffles.size ()) &&
906+ isCyclic (shiftRightAmounts, cycleLen) &&
907+ isCyclic (shiftLeftAmounts, cycleLen);
908+
814909 result.push_back ({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
815910 }
911+
912+ cycleLen = cyclicShifts ? cycleLen : 0 ;
913+ return result;
914+ }
915+
916+ Value BitCastRewriter::splatRewriteStep (
917+ PatternRewriter &rewriter, Location loc, Value initialValue,
918+ Value runningResult, const BitCastRewriter::Metadata &metadata) {
919+
920+ // Initial result will be the Shifted Mask which will have the shuffles size.
921+ int64_t inputSize = metadata.shuffles .size ();
922+ int64_t numCycles = inputSize / cycleLen;
923+
924+ auto shuffleOp = rewriter.create <vector::ShuffleOp>(
925+ loc, initialValue, initialValue, metadata.shuffles );
926+
927+ // Intersect with the mask.
928+ VectorType shuffledVectorType = shuffleOp.getResultVectorType ();
929+ auto constOp = rewriter.create <arith::ConstantOp>(
930+ loc, DenseElementsAttr::get (shuffledVectorType, metadata.masks ));
931+ Value andValue = rewriter.create <arith::AndIOp>(loc, shuffleOp, constOp);
932+
933+ Value result;
934+ for (int64_t idx = 0 ; idx < cycleLen; idx++) {
935+ auto shiftRightConstantOp = rewriter.create <arith::ConstantOp>(
936+ loc, SplatElementsAttr::get (shuffledVectorType,
937+ metadata.shiftRightAmounts [idx]));
938+ Value shiftedRight =
939+ rewriter.create <arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
940+
941+ auto shiftLeftConstantOp = rewriter.create <arith::ConstantOp>(
942+ loc, SplatElementsAttr::get (shuffledVectorType,
943+ metadata.shiftLeftAmounts [idx]));
944+ Value shiftedLeft =
945+ rewriter.create <arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
946+
947+ if (result) {
948+ SmallVector<int64_t > shuffles =
949+ constructShuffles (inputSize, numCycles, cycleLen, idx);
950+ result = rewriter.create <vector::ShuffleOp>(loc, result, shiftedLeft,
951+ shuffles);
952+
953+ // After the first shuffle in the chain, the size of the input result will
954+ // grow as we append more shuffles together to reconstruct the
955+ // shuffledVectorType size. Each iteration they will retain numCycles more
956+ // elements.
957+ inputSize = numCycles * (idx + 1 );
958+ } else {
959+ result = shiftedLeft;
960+ }
961+ }
962+
816963 return result;
817964}
818965
@@ -939,6 +1086,11 @@ namespace {
9391086struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9401087 using OpRewritePattern::OpRewritePattern;
9411088
1089+ RewriteBitCastOfTruncI (MLIRContext *context, PatternBenefit benefit,
1090+ unsigned maxCycleLen)
1091+ : OpRewritePattern<vector::BitCastOp>(context, benefit),
1092+ maxCycleLen{maxCycleLen} {}
1093+
9421094 LogicalResult matchAndRewrite (vector::BitCastOp bitCastOp,
9431095 PatternRewriter &rewriter) const override {
9441096 // The source must be a trunc op.
@@ -961,8 +1113,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9611113 Value runningResult;
9621114 for (const BitCastRewriter ::Metadata &metadata :
9631115 bcr.precomputeMetadata (shuffledElementType)) {
964- runningResult = bcr.genericRewriteStep (
965- rewriter, bitCastOp->getLoc (), truncValue, runningResult, metadata);
1116+ runningResult =
1117+ bcr.useSplatStep (maxCycleLen)
1118+ ? bcr.splatRewriteStep (rewriter, bitCastOp->getLoc (), truncValue,
1119+ runningResult, metadata)
1120+ : bcr.genericRewriteStep (rewriter, bitCastOp->getLoc (),
1121+ truncValue, runningResult, metadata);
9661122 }
9671123
9681124 // Finalize the rewrite.
@@ -986,6 +1142,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9861142
9871143 return success ();
9881144 }
1145+
1146+ private:
1147+ unsigned maxCycleLen;
9891148};
9901149} // namespace
9911150
@@ -1001,8 +1160,10 @@ template <typename ExtOpType>
10011160struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10021161 using OpRewritePattern<ExtOpType>::OpRewritePattern;
10031162
1004- RewriteExtOfBitCast (MLIRContext *context, PatternBenefit benefit)
1005- : OpRewritePattern<ExtOpType>(context, benefit) {}
1163+ RewriteExtOfBitCast (MLIRContext *context, PatternBenefit benefit,
1164+ unsigned maxCycleLen)
1165+ : OpRewritePattern<ExtOpType>(context, benefit),
1166+ maxCycleLen{maxCycleLen} {}
10061167
10071168 LogicalResult matchAndRewrite (ExtOpType extOp,
10081169 PatternRewriter &rewriter) const override {
@@ -1026,8 +1187,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10261187 cast<IntegerType>(getElementTypeOrSelf (sourceValue.getType ()));
10271188 for (const BitCastRewriter::Metadata &metadata :
10281189 bcr.precomputeMetadata (shuffledElementType)) {
1029- runningResult = bcr.genericRewriteStep (
1030- rewriter, bitCastOp->getLoc (), sourceValue, runningResult, metadata);
1190+ runningResult =
1191+ bcr.useSplatStep (maxCycleLen)
1192+ ? bcr.splatRewriteStep (rewriter, bitCastOp->getLoc (), sourceValue,
1193+ runningResult, metadata)
1194+ : bcr.genericRewriteStep (rewriter, bitCastOp->getLoc (),
1195+ sourceValue, runningResult, metadata);
10311196 }
10321197
10331198 // Finalize the rewrite.
@@ -1044,6 +1209,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10441209
10451210 return success ();
10461211 }
1212+
1213+ private:
1214+ unsigned maxCycleLen;
10471215};
10481216
10491217// / Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
@@ -1222,10 +1390,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
12221390}
12231391
12241392void vector::populateVectorNarrowTypeRewritePatterns (
1225- RewritePatternSet &patterns, PatternBenefit benefit) {
1393+ RewritePatternSet &patterns, PatternBenefit benefit, unsigned maxCycleLen ) {
12261394 patterns.add <RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
12271395 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext (),
1228- benefit);
1396+ benefit, maxCycleLen );
12291397
12301398 // Patterns for aligned cases. We set higher priority as they are expected to
12311399 // generate better performance for aligned cases.
0 commit comments