@@ -318,6 +318,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
Warning: these patterns currently only work for little endian targets.
}];

let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$max_cycle_len);

let assemblyFormat = "attr-dict";
}
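With the new max_cycle_len attribute and the unchanged attr-dict assembly format, the bound is passed inline. A minimal usage sketch (the same form appears in the updated transform test at the end of this diff; %f is assumed to be a handle to the payload function):

transform.apply_patterns to %f {
  transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
} : !transform.any_op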

@@ -376,7 +376,8 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
PatternBenefit benefit = 1,
unsigned maxCycleLen = 0);

/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
@@ -166,7 +166,8 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(

void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1,
getMaxCycleLen());
populateVectorTransposeNarrowTypeRewritePatterns(patterns);
}

159 changes: 151 additions & 8 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
/// [0] = {0, [0, 10)}, {1, [0, 5)}
/// [1] = {1, [5, 10)}, {2, [0, 10)}
/// and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
/// [0] = {0, [0, 4)}, {1, [0, 4)}
/// [1] = {2, [0, 4)}, {3, [0, 4)}
/// and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
/// [0] = {0, [0, 4)}
/// [1] = {0, [4, 8)}
/// [2] = {1, [0, 4)}
/// [3] = {1, [4, 8)}
struct BitCastBitsEnumerator {
BitCastBitsEnumerator(VectorType sourceVectorType,
VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
///
/// When using the above algorithm to rewrite a vector.bitcast, we rely on
/// dynamic shift amounts for the left and right shifts. This can be
/// inefficient on certain targets (e.g. RDNA GPUs) in contrast to shifts by a
/// splat constant. So, when possible, we rewrite the sequence as a
/// combination of shifts by a constant splat value and then regroup the
/// selected terms.
///
/// E.g. instead of:
/// res = arith.shrui x [0, 4, 8, 0, 4, 8]
/// use:
/// y = arith.shrui x [0, 0, 0, 0, 0, 0] (can be folded away)
/// y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
/// y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
/// y3 = vector.shuffle y y1 [0, 7, 3, 10]
/// res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
/// Here y3 = [y[0], y1[1], y[3], y1[4]], so the final shuffle picks y2[2] and
/// y2[5] (indices 6 and 9 of the concatenated operands) to yield
/// res = [y[0], y1[1], y2[2], y[3], y1[4], y2[5]].
///
/// This is possible when the precomputed shift amounts follow a cyclic
/// pattern [x, y, z, ..., x, y, z, ...] such that the cycle length, cycleLen,
/// satisfies 1 < cycleLen < size(shiftAmounts), and the shuffles are of the
/// form [0, 0, 0, ..., 1, 1, 1, ...]. This is a common pattern in
/// (de)quantization, e.g. i24 -> 3xi8 or 3xi8 -> i24. The modified algorithm
/// starts the same way as the algorithm above and then proceeds as follows:
///
/// 2. For each element x in the cycle of rightShiftAmounts, create a shrui
/// with a splat constant of x.
/// 3. Repeat 2. with the respective leftShiftAmounts.
/// 4. Construct a chain of vector.shuffles that reconstructs the result from
/// the chained shifts.
struct BitCastRewriter {
/// Helper metadata struct to hold the static quantities for the rewrite.
struct Metadata {
@@ -656,10 +693,25 @@ struct BitCastRewriter {
Value initialValue, Value runningResult,
const BitCastRewriter::Metadata &metadata);

/// Rewrite one step of the sequence when a splat constant can be used for
/// the shiftright and shiftleft amounts.
Value splatRewriteStep(PatternRewriter &rewriter, Location loc,
Value initialValue, Value runningResult,
const BitCastRewriter::Metadata &metadata);

/// Returns true when the splat-constant rewrite can be used, i.e. when the
/// precomputed shift amounts form a non-trivial cycle of length at most
/// `maxCycleLen`.
bool useSplatStep(unsigned maxCycleLen) const {
return 1 < cycleLen && cycleLen <= maxCycleLen;
}

private:
/// Underlying enumerator that encodes the provenance of the bits in each
/// element of the result vector.
BitCastBitsEnumerator enumerator;

/// Cycle length computed during precomputeMetadata(). A cycleLen > 1
/// indicates that the precomputed shift amounts are cyclic and that
/// splatRewriteStep can be used.
int64_t cycleLen = 0;
};

} // namespace
@@ -775,8 +827,31 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
return success();
}

/// Returns true if `xs` consists of repetitions of its first `cycleLen`
/// elements.
template <class T>
static bool isCyclic(const SmallVector<T> &xs, int64_t cycleLen) {
for (int64_t idx = cycleLen, n = xs.size(); idx < n; idx++) {
if (xs[idx] != xs[idx % cycleLen])
return false;
}
return true;
}
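// For illustration (hypothetical values): with shift amounts
// {0, 4, 8, 0, 4, 8}, isCyclic(amounts, /*cycleLen=*/3) returns true since
// every element equals the element cycleLen positions before it, while
// isCyclic(amounts, /*cycleLen=*/2) returns false (amounts[2] != amounts[0]).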

/// Constructs the shuffle mask that interleaves the accumulated result
/// (numCycles groups of `idx` elements) with one element per cycle taken
/// from the most recently shifted vector.
static SmallVector<int64_t> constructShuffles(int64_t numCycles, int64_t idx) {
SmallVector<int64_t> shuffles;
for (int64_t cycle = 0; cycle < numCycles; cycle++) {
for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++)
shuffles.push_back(cycle * idx + inputIdx);
shuffles.push_back(numCycles * idx + cycle);
}
return shuffles;
}
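// Worked example (values match the fext_splat2 test added below): with
// numCycles = 3 the mask grows by one interleaved element per cycle at each
// step:
//   constructShuffles(3, /*idx=*/1) == {0, 3, 1, 4, 2, 5}
//   constructShuffles(3, /*idx=*/2) == {0, 1, 6, 2, 3, 7, 4, 5, 8}
// At step idx, indices below numCycles * idx select from the accumulated
// result and index numCycles * idx + cycle selects from the newest shifts.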

SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
bool cyclicShifts = true;
SmallVector<BitCastRewriter::Metadata> result;
for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +886,55 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
IntegerAttr::get(shuffledElementType, shiftLeft));
}

// Compute a potential cycle length by counting the number of identical
// source elements at the start of the shuffle.
cycleLen = 1;
for (int64_t n = shuffles.size(); cycleLen < n; cycleLen++)
if (shuffles[cycleLen] != shuffles[0])
break;

cyclicShifts = cyclicShifts &&
(cycleLen < static_cast<int64_t>(shuffles.size())) &&
isCyclic(shiftRightAmounts, cycleLen) &&
isCyclic(shiftLeftAmounts, cycleLen);

result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
}

cycleLen = cyclicShifts ? cycleLen : 0;
return result;
}

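// Sketch of the emitted IR (mirroring the fext_splat1 test added below): for
// vector<2xi8> bitcast to vector<4xi4> followed by extui, cycleLen = 2 and
// numCycles = 2. The idx = 0 term folds to andi dense<15> (both shifts are by
// 0), the idx = 1 term is andi dense<-16> + shrui dense<4>, and a single
// vector.shuffle [0, 2, 1, 3] interleaves the two partial results.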
Value BitCastRewriter::splatRewriteStep(
PatternRewriter &rewriter, Location loc, Value initialValue,
Value runningResult, const BitCastRewriter::Metadata &metadata) {

int64_t numCycles = metadata.shuffles.size() / cycleLen;
auto vectorType = cast<ShapedType>(initialValue.getType());
Value result;
for (int64_t idx = 0; idx < cycleLen; idx++) {
// Intersect with the mask.
auto constOp = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(vectorType, metadata.masks[idx]));
Value andValue = rewriter.create<arith::AndIOp>(loc, initialValue, constOp);

auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
loc,
SplatElementsAttr::get(vectorType, metadata.shiftRightAmounts[idx]));
Value shiftedRight =
rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
loc,
SplatElementsAttr::get(vectorType, metadata.shiftLeftAmounts[idx]));
Value shiftedLeft =
rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

SmallVector<int64_t> shuffles = constructShuffles(numCycles, idx);
result = result ? rewriter.create<vector::ShuffleOp>(loc, result,
shiftedLeft, shuffles)
: shiftedLeft;
}

return result;
}

@@ -939,6 +1061,11 @@ namespace {
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;

RewriteBitCastOfTruncI(MLIRContext *context, PatternBenefit benefit,
unsigned maxCycleLen)
: OpRewritePattern<vector::BitCastOp>(context, benefit),
maxCycleLen{maxCycleLen} {}

LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
PatternRewriter &rewriter) const override {
// The source must be a trunc op.
@@ -961,8 +1088,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
Value runningResult;
for (const BitCastRewriter::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
runningResult = bcr.genericRewriteStep(
rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
runningResult =
bcr.useSplatStep(maxCycleLen)
? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
runningResult, metadata)
: bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
truncValue, runningResult, metadata);
}

// Finalize the rewrite.
@@ -986,6 +1117,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {

return success();
}

private:
unsigned maxCycleLen;
};
} // namespace

@@ -1001,8 +1135,10 @@ template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
using OpRewritePattern<ExtOpType>::OpRewritePattern;

RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
: OpRewritePattern<ExtOpType>(context, benefit) {}
RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit,
unsigned maxCycleLen)
: OpRewritePattern<ExtOpType>(context, benefit),
maxCycleLen{maxCycleLen} {}

LogicalResult matchAndRewrite(ExtOpType extOp,
PatternRewriter &rewriter) const override {
@@ -1026,8 +1162,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
for (const BitCastRewriter::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
runningResult = bcr.genericRewriteStep(
rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
runningResult =
bcr.useSplatStep(maxCycleLen)
? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), sourceValue,
runningResult, metadata)
: bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
sourceValue, runningResult, metadata);
}

// Finalize the rewrite.
@@ -1044,6 +1184,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {

return success();
}

private:
unsigned maxCycleLen;
};

/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
@@ -1222,10 +1365,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
}

void vector::populateVectorNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, PatternBenefit benefit, unsigned maxCycleLen) {
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
benefit);
benefit, maxCycleLen);

// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
82 changes: 81 additions & 1 deletion mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -146,6 +146,42 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
return %1 : vector<8xi6>
}

// CHECK-LABEL: func.func @ftrunc_splat1(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi16>) -> vector<1xi8> {
func.func @ftrunc_splat1(%a: vector<2xi16>) -> vector<1xi8> {
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<1xi16>
// CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<1xi16>
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0] : vector<2xi16>, vector<2xi16>
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<1xi16>
// CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1] : vector<2xi16>, vector<2xi16>
// CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<1xi16>
// CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<1xi16>
// CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<1xi16>
// CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<1xi16> to vector<1xi8>
// CHECK: return %[[RES]] : vector<1xi8>
%0 = arith.trunci %a : vector<2xi16> to vector<2xi4>
%1 = vector.bitcast %0 : vector<2xi4> to vector<1xi8>
return %1 : vector<1xi8>
}

// CHECK-LABEL: func.func @ftrunc_splat2(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<2xi8> {
func.func @ftrunc_splat2(%a: vector<4xi16>) -> vector<2xi8> {
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<2xi16>
// CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<2xi16>
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 2] : vector<4xi16>, vector<4xi16>
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<2xi16>
// CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 3] : vector<4xi16>, vector<4xi16>
// CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<2xi16>
// CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<2xi16>
// CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<2xi16>
// CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<2xi16> to vector<2xi8>
// CHECK: return %[[RES]] : vector<2xi8>
%0 = arith.trunci %a : vector<4xi16> to vector<4xi4>
%1 = vector.bitcast %0 : vector<4xi4> to vector<2xi8>
return %1 : vector<2xi8>
}

// CHECK-LABEL: func.func @f1ext(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
@@ -193,6 +229,50 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}

// CHECK-LABEL: func.func @fext_splat1(
// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<2xi8>
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<-16> : vector<2xi8>
// CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<2xi8>
// CHECK-DAG: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<2xi8>
// CHECK-DAG: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<2xi8>
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST]] : vector<2xi8>
// CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 2, 1, 3] : vector<2xi8>, vector<2xi8>
// CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
// CHECK: return %[[RES]] : vector<4xi16>
%0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
%1 = arith.extui %0 : vector<4xi4> to vector<4xi16>
return %1 : vector<4xi16>
}

// CHECK-LABEL: func.func @fext_splat2(
// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<3xi16>
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<240> : vector<3xi16>
// CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<3840> : vector<3xi16>
// CHECK-DAG: %[[MASK3:.*]] = arith.constant dense<-4096> : vector<3xi16>
// CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<3xi16>
// CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<3xi16>
// CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<3xi16>
// CHECK: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<3xi16>
// CHECK: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<3xi16>
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST0]] : vector<3xi16>
// CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 3, 1, 4, 2, 5] : vector<3xi16>, vector<3xi16>
// CHECK: %[[A2:.*]] = arith.andi %[[ARG]], %[[MASK2]] : vector<3xi16>
// CHECK: %[[SHR1:.*]] = arith.shrui %[[A2]], %[[SHR_CST1]] : vector<3xi16>
// CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 6, 2, 3, 7, 4, 5, 8] : vector<6xi16>, vector<3xi16>
// CHECK: %[[A3:.*]] = arith.andi %[[ARG]], %[[MASK3]] : vector<3xi16>
// CHECK: %[[SHR2:.*]] = arith.shrui %[[A3]], %[[SHR_CST2]] : vector<3xi16>
// CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 9, 3, 4, 5, 10, 6, 7, 8, 11] : vector<9xi16>, vector<3xi16>
// CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
// CHECK: return %[[RES]] : vector<12xi32>
%0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>
%1 = arith.extui %0 : vector<12xi4> to vector<12xi32>
return %1 : vector<12xi32>
}
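// Note: this case has a shift-amount cycle of length 4 ([0, 4, 8, 12] per
// source i16), so the splat-based rewrite only fires because the transform
// configuration later in this file sets max_cycle_len = 4.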

// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
@@ -330,7 +410,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op

transform.apply_patterns to %f {
transform.apply_patterns.vector.rewrite_narrow_types
transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
} : !transform.any_op
transform.yield
}