Skip to content

Commit 47bbdf1

Browse files
committed
update VectorTransforms.cpp to use broadcast
1 parent 5aa7757 commit 47bbdf1

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
938938

939939
Value zero = rewriter.create<arith::ConstantOp>(
940940
loc, elemType, rewriter.getZeroAttr(elemType));
941-
Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
941+
Value res = rewriter.create<BroadcastOp>(loc, castDstType, zero);
942942

943943
SmallVector<int64_t> sliceShape = {castDstLastDim};
944944
SmallVector<int64_t> strides = {1};
@@ -964,6 +964,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
964964
std::function<bool(BitCastOp)> controlFn;
965965
};
966966

967+
/// If \p v is result of a splat or broadcast operation, return the input of the
968+
/// broadcast/splat operation.
969+
static Value getBroadcastLikeSource(Value v) {
970+
971+
Operation *op = v.getDefiningOp();
972+
if (!op)
973+
return {};
974+
975+
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
976+
return broadcast.getSource();
977+
978+
if (auto splat = dyn_cast<vector::SplatOp>(op))
979+
return splat.getInput();
980+
981+
return {};
982+
}
983+
967984
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
968985
///
969986
/// Example:
@@ -1005,26 +1022,23 @@ struct ReorderElementwiseOpsOnBroadcast final
10051022
}
10061023

10071024
// Get the type of the lhs operand
1008-
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
1009-
if (!lhsBcastOrSplat ||
1010-
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
1011-
return failure();
1012-
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1025+
Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
1026+
if (!lhsSource)
1027+
return rewriter.notifyMatchFailure(
1028+
op, "operand #0 not the result of a broadcast");
1029+
Type lhsBcastOrSplatType = lhsSource.getType();
10131030

10141031
// Make sure that all operands are broadcast from identical types:
10151032
// * scalar (`vector.broadcast` + `vector.splat`), or
10161033
// * vector (`vector.broadcast`).
10171034
// Otherwise the re-ordering wouldn't be safe.
1018-
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1019-
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1020-
if (bcast)
1021-
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1022-
auto splat = val.getDefiningOp<vector::SplatOp>();
1023-
if (splat)
1024-
return (splat.getOperand().getType() == lhsBcastOrSplatType);
1035+
if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
1036+
if (auto source = getBroadcastLikeSource(val))
1037+
return source.getType() == lhsBcastOrSplatType;
10251038
return false;
10261039
})) {
1027-
return failure();
1040+
return rewriter.notifyMatchFailure(
1041+
op, "not all operands are broadcasts from the sametype");
10281042
}
10291043

10301044
// Collect the source values before broadcasting
@@ -1232,15 +1246,17 @@ class StoreOpFromSplatOrBroadcast final
12321246
return rewriter.notifyMatchFailure(
12331247
op, "only 1-element vectors are supported");
12341248

1235-
Operation *splat = op.getValueToStore().getDefiningOp();
1236-
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1237-
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
1249+
Value toStore = op.getValueToStore();
1250+
Value source = getBroadcastLikeSource(toStore);
1251+
if (!source)
1252+
return rewriter.notifyMatchFailure(
1253+
op, "value to store is not from a broadcast");
12381254

12391255
// Checking for single use so we can remove splat.
1256+
Operation *splat = toStore.getDefiningOp();
12401257
if (!splat->hasOneUse())
12411258
return rewriter.notifyMatchFailure(op, "expected single op use");
12421259

1243-
Value source = splat->getOperand(0);
12441260
Value base = op.getBase();
12451261
ValueRange indices = op.getIndices();
12461262

@@ -1290,13 +1306,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
12901306
// Add in an offset if requested.
12911307
if (off) {
12921308
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1293-
Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1309+
Value ov = rewriter.create<vector::BroadcastOp>(loc, indices.getType(), o);
12941310
indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
12951311
}
12961312
// Construct the vector comparison.
12971313
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
12981314
Value bounds =
1299-
rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1315+
rewriter.create<vector::BroadcastOp>(loc, indices.getType(), bound);
13001316
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
13011317
bounds);
13021318
}

0 commit comments

Comments
 (0)