Skip to content

Commit e29a6da

Browse files
committed
[WIP] using splat shifts
1 parent 07a1fbe commit e29a6da

File tree

6 files changed

+325
-13
lines changed

6 files changed

+325
-13
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
318318
Warning: these patterns currently only work for little endian targets.
319319
}];
320320

321+
let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$max_cycle_len);
322+
321323
let assemblyFormat = "attr-dict";
322324
}
323325

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
376376
/// ops over wider types.
377377
/// Warning: these patterns currently only work for little endian targets.
378378
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
379-
PatternBenefit benefit = 1);
379+
PatternBenefit benefit = 1,
380+
unsigned shiftDepth = 0);
380381

381382
/// Appends patterns for emulating a sub-byte vector transpose.
382383
void populateVectorTransposeNarrowTypeRewritePatterns(

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
166166

167167
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
168168
RewritePatternSet &patterns) {
169-
populateVectorNarrowTypeRewritePatterns(patterns);
169+
populateVectorNarrowTypeRewritePatterns(patterns, /*default=*/1,
170+
getMaxCycleLen());
170171
populateVectorTransposeNarrowTypeRewritePatterns(patterns);
171172
}
172173

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 176 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
546546
/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
547547
/// [0] = {0, [0, 10)}, {1, [0, 5)}
548548
/// [1] = {1, [5, 10)}, {2, [0, 10)}
549+
/// and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
550+
/// [0] = {0, [0, 4)}, {1, [0, 4)}
551+
/// [1] = {2, [0, 4)}, {3, [0, 4)}
552+
/// and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
553+
/// [0] = {0, [0, 4)}
554+
/// [1] = {0, [4, 8)}
555+
/// [2] = {1, [0, 4)}
556+
/// [3] = {1, [4, 8)}
549557
struct BitCastBitsEnumerator {
550558
BitCastBitsEnumerator(VectorType sourceVectorType,
551559
VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
633641
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
634642
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
635643
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
644+
///
645+
///
646+
/// When we consider the above algorithm to rewrite our vector.bitcast, we rely
647+
/// on using dynamic shift amounts for the left and right shifts. This can be
648+
/// inefficient on certain targets (RDNA GPUs) in contrast to a splat constant
649+
/// value. So when possible we can rewrite this as a combination of shifts with
650+
/// a constant splat value and then regroup the selected terms.
651+
///
652+
/// Eg. Instead of:
653+
/// res = arith.shrui x [0, 4, 8, 0, 4, 8]
654+
/// use:
655+
/// y = arith.shrui x [0, 0, 0, 0, 0, 0] (can be folded away)
656+
/// y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
657+
/// y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
658+
/// y3 = vector.shuffle y y1 [0, 7, 3, 10]
659+
/// res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
660+
///
661+
/// This is possible when the precomputed shift amounts follow a cyclic
662+
/// pattern of [x, y, z, ..., x, y, z, ...] such that the cycle length,
663+
/// cycleLen, satisfies 1 < cycleLen < size(shiftAmounts). And the shuffles are
664+
/// of the form [0, 0, 0, ..., 1, 1, 1, ...]. A common pattern in
665+
/// (de)quantization, i24 -> 3xi8 or 3xi8 -> i24. The modified algorithm follows
666+
/// the same 2 steps as above, then it proceeds as follows:
667+
///
668+
/// 2. for each element in the cycle, x, of the rightShiftAmounts create a
669+
/// shrui with a splat constant of x.
670+
/// 3. repeat 2. with the respective leftShiftAmounts
671+
/// 4. construct a chain of vector.shuffles that will reconstruct the result
672+
/// from the chained shifts
636673
struct BitCastRewriter {
637674
/// Helper metadata struct to hold the static quantities for the rewrite.
638675
struct Metadata {
@@ -656,10 +693,25 @@ struct BitCastRewriter {
656693
Value initialValue, Value runningResult,
657694
const BitCastRewriter::Metadata &metadata);
658695

696+
/// Rewrite one step of the sequence when able to use a splat constant for the
697+
/// shiftright and shiftleft.
698+
Value splatRewriteStep(PatternRewriter &rewriter, Location loc,
699+
Value initialValue, Value runningResult,
700+
const BitCastRewriter::Metadata &metadata);
701+
702+
bool useSplatStep(unsigned maxCycleLen) {
703+
return 1 < cycleLen && cycleLen <= maxCycleLen;
704+
}
705+
659706
private:
660707
/// Underlying enumerator that encodes the provenance of the bits in each
661708
/// element of the result vector.
662709
BitCastBitsEnumerator enumerator;
710+
711+
// Underlying cycleLen computed during precomputeMetadata. A cycleLen > 1
712+
// denotes that there is a cycle in the precomputed shift amounts and we are
713+
// able to use the splatRewriteStep.
714+
int64_t cycleLen = 0;
663715
};
664716

665717
} // namespace
@@ -775,8 +827,40 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
775827
return success();
776828
}
777829

830+
// Check if the vector is a cycle of the first cycleLen elements.
831+
template <class T>
832+
static bool isCyclic(SmallVector<T> xs, int64_t cycleLen) {
833+
for (int64_t idx = cycleLen, n = xs.size(); idx < n; idx++) {
834+
if (xs[idx] != xs[idx % cycleLen])
835+
return false;
836+
}
837+
return true;
838+
}
839+
840+
/// Builds the shuffle mask that merges the running result (first shuffle
/// operand) with the next splat-shifted vector (second operand) while
/// reconstructing the bitcast result from per-cycle shifts.
///
/// When idx == 1 the first operand is the masked shuffle of the original
/// size, so its elements are visited with a stride of cycleLen. When
/// idx > 1 the first operand is a previous merge holding idx elements per
/// cycle, so the first idx elements of each cycle are taken contiguously
/// before appending the strided element from the second operand.
static SmallVector<int64_t> constructShuffles(int64_t inputSize,
                                              int64_t numCycles,
                                              int64_t cycleLen, int64_t idx) {
  const int64_t firstOperandStride = (idx == 1) ? cycleLen : idx;

  SmallVector<int64_t> mask;
  for (int64_t cycle = 0; cycle != numCycles; ++cycle) {
    // Retain the idx elements already merged for this cycle.
    const int64_t base = cycle * firstOperandStride;
    for (int64_t elem = 0; elem != idx; ++elem)
      mask.push_back(base + elem);
    // Append the cycle's element from the second operand; its indices start
    // at inputSize in vector.shuffle numbering.
    mask.push_back(inputSize + cycle * cycleLen + idx);
  }
  return mask;
}
860+
778861
SmallVector<BitCastRewriter::Metadata>
779862
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
863+
bool cyclicShifts = true;
780864
SmallVector<BitCastRewriter::Metadata> result;
781865
for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
782866
shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +895,71 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
811895
IntegerAttr::get(shuffledElementType, shiftLeft));
812896
}
813897

898+
// Compute a potential cycle size by detecting the number of sourceElements
899+
// at the start of shuffle that are the same
900+
cycleLen = 1;
901+
for (int64_t n = shuffles.size(); cycleLen < n; cycleLen++)
902+
if (shuffles[cycleLen] != shuffles[0])
903+
break;
904+
905+
cyclicShifts = cyclicShifts && (cycleLen < (int64_t)shuffles.size()) &&
906+
isCyclic(shiftRightAmounts, cycleLen) &&
907+
isCyclic(shiftLeftAmounts, cycleLen);
908+
814909
result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
815910
}
911+
912+
cycleLen = cyclicShifts ? cycleLen : 0;
913+
return result;
914+
}
915+
916+
/// Rewrites one step of the bitcast sequence using splat-constant shift
/// amounts instead of per-element dynamic shift vectors, then reassembles
/// the result with a chain of vector.shuffles (see the algorithm note on
/// BitCastRewriter).
///
/// NOTE(review): `runningResult` is accepted for signature parity with
/// genericRewriteStep but is never folded into the result here — confirm
/// this is intended when precomputeMetadata yields more than one entry.
Value BitCastRewriter::splatRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {

  // The merge chain starts from the masked shuffle, whose size is that of
  // `metadata.shuffles`.
  int64_t firstOperandSize = metadata.shuffles.size();
  const int64_t numCycles = firstOperandSize / cycleLen;

  // Regroup the source elements.
  auto shuffled = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Mask off the bits that do not contribute to each target element.
  VectorType shuffledVectorType = shuffled.getResultVectorType();
  auto maskConstant = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value masked = rewriter.create<arith::AndIOp>(loc, shuffled, maskConstant);

  // One splat shrui/shli pair per element of the cycle, merged back together
  // with a chain of shuffles.
  Value merged;
  for (int64_t cyclePos = 0; cyclePos != cycleLen; ++cyclePos) {
    auto rightAmount = rewriter.create<arith::ConstantOp>(
        loc, SplatElementsAttr::get(shuffledVectorType,
                                    metadata.shiftRightAmounts[cyclePos]));
    Value rshifted =
        rewriter.create<arith::ShRUIOp>(loc, masked, rightAmount);

    auto leftAmount = rewriter.create<arith::ConstantOp>(
        loc, SplatElementsAttr::get(shuffledVectorType,
                                    metadata.shiftLeftAmounts[cyclePos]));
    Value aligned =
        rewriter.create<arith::ShLIOp>(loc, rshifted, leftAmount);

    if (!merged) {
      // First element of the cycle: nothing to merge with yet.
      merged = aligned;
      continue;
    }

    SmallVector<int64_t> mergeMask =
        constructShuffles(firstOperandSize, numCycles, cycleLen, cyclePos);
    merged =
        rewriter.create<vector::ShuffleOp>(loc, merged, aligned, mergeMask);
    // Each merge retains numCycles more elements, so the first operand of
    // the next shuffle grows accordingly.
    firstOperandSize = numCycles * (cyclePos + 1);
  }

  return merged;
}
818965

@@ -939,6 +1086,11 @@ namespace {
9391086
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9401087
using OpRewritePattern::OpRewritePattern;
9411088

1089+
RewriteBitCastOfTruncI(MLIRContext *context, PatternBenefit benefit,
1090+
unsigned maxCycleLen)
1091+
: OpRewritePattern<vector::BitCastOp>(context, benefit),
1092+
maxCycleLen{maxCycleLen} {}
1093+
9421094
LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
9431095
PatternRewriter &rewriter) const override {
9441096
// The source must be a trunc op.
@@ -961,8 +1113,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9611113
Value runningResult;
9621114
for (const BitCastRewriter ::Metadata &metadata :
9631115
bcr.precomputeMetadata(shuffledElementType)) {
964-
runningResult = bcr.genericRewriteStep(
965-
rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1116+
runningResult =
1117+
bcr.useSplatStep(maxCycleLen)
1118+
? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
1119+
runningResult, metadata)
1120+
: bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
1121+
truncValue, runningResult, metadata);
9661122
}
9671123

9681124
// Finalize the rewrite.
@@ -986,6 +1142,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
9861142

9871143
return success();
9881144
}
1145+
1146+
private:
1147+
unsigned maxCycleLen;
9891148
};
9901149
} // namespace
9911150

@@ -1001,8 +1160,10 @@ template <typename ExtOpType>
10011160
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10021161
using OpRewritePattern<ExtOpType>::OpRewritePattern;
10031162

1004-
RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1005-
: OpRewritePattern<ExtOpType>(context, benefit) {}
1163+
RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit,
1164+
unsigned maxCycleLen)
1165+
: OpRewritePattern<ExtOpType>(context, benefit),
1166+
maxCycleLen{maxCycleLen} {}
10061167

10071168
LogicalResult matchAndRewrite(ExtOpType extOp,
10081169
PatternRewriter &rewriter) const override {
@@ -1026,8 +1187,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10261187
cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
10271188
for (const BitCastRewriter::Metadata &metadata :
10281189
bcr.precomputeMetadata(shuffledElementType)) {
1029-
runningResult = bcr.genericRewriteStep(
1030-
rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1190+
runningResult =
1191+
bcr.useSplatStep(maxCycleLen)
1192+
? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), sourceValue,
1193+
runningResult, metadata)
1194+
: bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
1195+
sourceValue, runningResult, metadata);
10311196
}
10321197

10331198
// Finalize the rewrite.
@@ -1044,6 +1209,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10441209

10451210
return success();
10461211
}
1212+
1213+
private:
1214+
unsigned maxCycleLen;
10471215
};
10481216

10491217
/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
@@ -1222,10 +1390,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
12221390
}
12231391

12241392
void vector::populateVectorNarrowTypeRewritePatterns(
1225-
RewritePatternSet &patterns, PatternBenefit benefit) {
1393+
RewritePatternSet &patterns, PatternBenefit benefit, unsigned maxCycleLen) {
12261394
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
12271395
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1228-
benefit);
1396+
benefit, maxCycleLen);
12291397

12301398
// Patterns for aligned cases. We set higher priority as they are expected to
12311399
// generate better performance for aligned cases.

0 commit comments

Comments
 (0)