@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
546546// / and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
547547// / [0] = {0, [0, 10)}, {1, [0, 5)}
548548// / [1] = {1, [5, 10)}, {2, [0, 10)}
549+ // / and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
550+ // / [0] = {0, [0, 4)}, {1, [0, 4)}
551+ // / [1] = {2, [0, 4)}, {3, [0, 4)}
552+ // / and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
553+ // / [0] = {0, [0, 4)}
554+ // / [1] = {0, [4, 8)}
555+ // / [2] = {1, [0, 4)}
556+ // / [3] = {1, [4, 8)}
549557struct BitCastBitsEnumerator {
550558 BitCastBitsEnumerator (VectorType sourceVectorType,
551559 VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
633641// / `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
634642// / the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
635643// / bits extracted from the source vector (i.e. the `shuffle -> and` part).
644+ // /
645+ // /
646+ // / When we consider the above algorithm to rewrite our vector.bitcast, we rely
647+ // / on using dynamic shift amounts for the left and right shifts. This can be
648+ // / inefficient on certain targets (RDNA GPUs) in contrast to a splat constant
649+ // / value. So when possible we can rewrite this as a combination of shifts with
650+ // / a constant splat value and then regroup the selected terms.
651+ // /
652+ // / Eg. Instead of:
653+ // / res = arith.shrui x [0, 4, 8, 0, 4, 8]
654+ // / use:
655+ // / y = arith.shrui x [0, 0, 0, 0, 0, 0] (can be folded away)
656+ // / y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
657+ // / y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
658+ // / y3 = vector.shuffle y y1 [0, 7, 3, 10]
659+ // / res = vector.shuffle y3 y2 [0, 1, 7, 2, 3, 10]
660+ // /
661+ // / This is possible when the precomputed shift amounts following a cyclic
662+ // / pattern of [x, y, z, ..., x, y, z, ...] such that the cycle length,
663+ // / cycleLen, satisifies 1 < cycleLen < size(shiftAmounts). And the shuffles are
664+ // / of the form [0, 0, 0, ..., 1, 1, 1, ...]. A common pattern in
665+ // / (de)quantization, i24 -> 3xi8 or 3xi8 -> i24. The modified algorithm follows
666+ // / the same 2 steps as above, then it proceeds as follows:
667+ // /
668+ // / 2. for each element in the cycle, x, of the rightShiftAmounts create a
669+ // / shrui with a splat constant of x.
670+ // / 3. repeat 2. with the respective leftShiftAmounts
671+ // / 4. construct a chain of vector.shuffles that will reconstruct the result
672+ // / from the chained shifts
636673struct BitCastRewriter {
637674 // / Helper metadata struct to hold the static quantities for the rewrite.
638675 struct Metadata {
@@ -656,10 +693,23 @@ struct BitCastRewriter {
656693 Value initialValue, Value runningResult,
657694 const BitCastRewriter::Metadata &metadata);
658695
696+ // / Rewrite one step of the sequence when able to use a splat constant for the
697+ // / shiftright and shiftleft.
698+ Value splatRewriteStep (PatternRewriter &rewriter, Location loc,
699+ Value initialValue, Value runningResult,
700+ const BitCastRewriter::Metadata &metadata);
701+
702+ bool useSplatStep () { return cycleLen > 1 ; }
703+
659704private:
660705 // / Underlying enumerator that encodes the provenance of the bits in the each
661706 // / element of the result vector.
662707 BitCastBitsEnumerator enumerator;
708+
709+ // Underlying cycleLen computed during precomputeMetadata. A cycleLen > 1
710+ // denotes that there is a cycle in the precomputed shift amounts and we are
711+ // able to use the splatRewriteStep.
712+ int64_t cycleLen = 0 ;
663713};
664714
665715} // namespace
@@ -775,8 +825,40 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
775825 return success ();
776826}
777827
828+ // Check if the vector is a cycle of the first cycleLen elements.
829+ template <class T >
830+ static bool isCyclic (SmallVector<T> xs, int64_t cycleLen) {
831+ for (int64_t idx = cycleLen, n = xs.size (); idx < n; idx++) {
832+ if (xs[idx] != xs[idx % cycleLen])
833+ return false ;
834+ }
835+ return true ;
836+ }
837+
838+ static SmallVector<int64_t > constructShuffles (int64_t inputSize,
839+ int64_t numCycles,
840+ int64_t cycleLen, int64_t idx) {
841+ // If idx == 1, then the first operand of the shuffle will be the mask which
842+ // will have the original size. So we need to step through the mask with a
843+ // stride of cycleSize.
844+ // When idx > 1, then the first operand will be the size of (idx * cycleSize)
845+ // and so we take the first idx elements of the input and then append the
846+ // strided mask value.
847+ int64_t inputStride = idx == 1 ? cycleLen : idx;
848+
849+ SmallVector<int64_t > shuffles;
850+ for (int64_t cycle = 0 ; cycle < numCycles; cycle++) {
851+ for (int64_t inputIdx = 0 ; inputIdx < idx; inputIdx++) {
852+ shuffles.push_back (cycle * inputStride + inputIdx);
853+ }
854+ shuffles.push_back (inputSize + cycle * cycleLen + idx);
855+ }
856+ return shuffles;
857+ }
858+
778859SmallVector<BitCastRewriter::Metadata>
779860BitCastRewriter::precomputeMetadata (IntegerType shuffledElementType) {
861+ bool cyclicShifts = true ;
780862 SmallVector<BitCastRewriter::Metadata> result;
781863 for (int64_t shuffleIdx = 0 , e = enumerator.getMaxNumberOfEntries ();
782864 shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +893,71 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
811893 IntegerAttr::get (shuffledElementType, shiftLeft));
812894 }
813895
896+ // Compute a potential cycle size by detecting the number of sourceElements
897+ // at the start of shuffle that are the same
898+ cycleLen = 1 ;
899+ for (int64_t n = shuffles.size (); cycleLen < n; cycleLen++)
900+ if (shuffles[cycleLen] != shuffles[0 ])
901+ break ;
902+
903+ cyclicShifts = cyclicShifts && (cycleLen < (int64_t )shuffles.size ()) &&
904+ isCyclic (shiftRightAmounts, cycleLen) &&
905+ isCyclic (shiftLeftAmounts, cycleLen);
906+
814907 result.push_back ({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
815908 }
909+
910+ cycleLen = cyclicShifts ? cycleLen : 0 ;
911+ return result;
912+ }
913+
914+ Value BitCastRewriter::splatRewriteStep (
915+ PatternRewriter &rewriter, Location loc, Value initialValue,
916+ Value runningResult, const BitCastRewriter::Metadata &metadata) {
917+
918+ // Initial result will be the Shifted Mask which will have the shuffles size.
919+ int64_t inputSize = metadata.shuffles .size ();
920+ int64_t numCycles = inputSize / cycleLen;
921+
922+ auto shuffleOp = rewriter.create <vector::ShuffleOp>(
923+ loc, initialValue, initialValue, metadata.shuffles );
924+
925+ // Intersect with the mask.
926+ VectorType shuffledVectorType = shuffleOp.getResultVectorType ();
927+ auto constOp = rewriter.create <arith::ConstantOp>(
928+ loc, DenseElementsAttr::get (shuffledVectorType, metadata.masks ));
929+ Value andValue = rewriter.create <arith::AndIOp>(loc, shuffleOp, constOp);
930+
931+ Value result;
932+ for (int64_t idx = 0 ; idx < cycleLen; idx++) {
933+ auto shiftRightConstantOp = rewriter.create <arith::ConstantOp>(
934+ loc, SplatElementsAttr::get (shuffledVectorType,
935+ metadata.shiftRightAmounts [idx]));
936+ Value shiftedRight =
937+ rewriter.create <arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
938+
939+ auto shiftLeftConstantOp = rewriter.create <arith::ConstantOp>(
940+ loc, SplatElementsAttr::get (shuffledVectorType,
941+ metadata.shiftLeftAmounts [idx]));
942+ Value shiftedLeft =
943+ rewriter.create <arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
944+
945+ if (result) {
946+ SmallVector<int64_t > shuffles =
947+ constructShuffles (inputSize, numCycles, cycleLen, idx);
948+ result = rewriter.create <vector::ShuffleOp>(loc, result, shiftedLeft,
949+ shuffles);
950+
951+ // After the first shuffle in the chain, the size of the input result will
952+ // grow as we append more shuffles together to reconstruct the
953+ // shuffledVectorType size. Each iteration they will retain numCycles more
954+ // elements.
955+ inputSize = numCycles * (idx + 1 );
956+ } else {
957+ result = shiftedLeft;
958+ }
959+ }
960+
816961 return result;
817962}
818963
@@ -961,8 +1106,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9611106 Value runningResult;
9621107 for (const BitCastRewriter ::Metadata &metadata :
9631108 bcr.precomputeMetadata (shuffledElementType)) {
964- runningResult = bcr.genericRewriteStep (
965- rewriter, bitCastOp->getLoc (), truncValue, runningResult, metadata);
1109+ runningResult =
1110+ bcr.useSplatStep ()
1111+ ? bcr.splatRewriteStep (rewriter, bitCastOp->getLoc (), truncValue,
1112+ runningResult, metadata)
1113+ : bcr.genericRewriteStep (rewriter, bitCastOp->getLoc (),
1114+ truncValue, runningResult, metadata);
9661115 }
9671116
9681117 // Finalize the rewrite.
@@ -1026,8 +1175,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10261175 cast<IntegerType>(getElementTypeOrSelf (sourceValue.getType ()));
10271176 for (const BitCastRewriter::Metadata &metadata :
10281177 bcr.precomputeMetadata (shuffledElementType)) {
1029- runningResult = bcr.genericRewriteStep (
1030- rewriter, bitCastOp->getLoc (), sourceValue, runningResult, metadata);
1178+ runningResult =
1179+ bcr.useSplatStep ()
1180+ ? bcr.splatRewriteStep (rewriter, bitCastOp->getLoc (), sourceValue,
1181+ runningResult, metadata)
1182+ : bcr.genericRewriteStep (rewriter, bitCastOp->getLoc (),
1183+ sourceValue, runningResult, metadata);
10311184 }
10321185
10331186 // Finalize the rewrite.
0 commit comments