@@ -938,7 +938,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
938938
939939 Value zero = rewriter.create <arith::ConstantOp>(
940940 loc, elemType, rewriter.getZeroAttr (elemType));
941- Value res = rewriter.create <SplatOp >(loc, castDstType, zero);
941+ Value res = rewriter.create <BroadcastOp >(loc, castDstType, zero);
942942
943943 SmallVector<int64_t > sliceShape = {castDstLastDim};
944944 SmallVector<int64_t > strides = {1 };
@@ -964,6 +964,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
964964 std::function<bool (BitCastOp)> controlFn;
965965};
966966
967+ // / If \p v is result of a splat or broadcast operation, return the input of the
968+ // / broadcast/splat operation.
969+ static Value getBroadcastLikeSource (Value v) {
970+
971+ Operation *op = v.getDefiningOp ();
972+ if (!op)
973+ return {};
974+
975+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
976+ return broadcast.getSource ();
977+
978+ if (auto splat = dyn_cast<vector::SplatOp>(op))
979+ return splat.getInput ();
980+
981+ return {};
982+ }
983+
967984// / Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
968985// /
969986// / Example:
@@ -1005,26 +1022,23 @@ struct ReorderElementwiseOpsOnBroadcast final
10051022 }
10061023
10071024 // Get the type of the lhs operand
1008- auto *lhsBcastOrSplat = op->getOperand (0 ). getDefiningOp ( );
1009- if (!lhsBcastOrSplat ||
1010- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
1011- return failure ( );
1012- auto lhsBcastOrSplatType = lhsBcastOrSplat-> getOperand ( 0 ) .getType ();
1025+ Value lhsSource = getBroadcastLikeSource ( op->getOperand (0 ));
1026+ if (!lhsSource)
1027+ return rewriter. notifyMatchFailure (
1028+ op, " operand #0 not the result of a broadcast " );
1029+ Type lhsBcastOrSplatType = lhsSource .getType ();
10131030
10141031 // Make sure that all operands are broadcast from identical types:
10151032 // * scalar (`vector.broadcast` + `vector.splat`), or
10161033 // * vector (`vector.broadcast`).
10171034 // Otherwise the re-ordering wouldn't be safe.
1018- if (!llvm::all_of (op->getOperands (), [&lhsBcastOrSplatType](Value val) {
1019- auto bcast = val.getDefiningOp <vector::BroadcastOp>();
1020- if (bcast)
1021- return (bcast.getOperand ().getType () == lhsBcastOrSplatType);
1022- auto splat = val.getDefiningOp <vector::SplatOp>();
1023- if (splat)
1024- return (splat.getOperand ().getType () == lhsBcastOrSplatType);
1035+ if (!llvm::all_of (op->getOperands (), [lhsBcastOrSplatType](Value val) {
1036+ if (auto source = getBroadcastLikeSource (val))
1037+ return source.getType () == lhsBcastOrSplatType;
10251038 return false ;
10261039 })) {
1027- return failure ();
1040+ return rewriter.notifyMatchFailure (
1041+ op, " not all operands are broadcasts from the sametype" );
10281042 }
10291043
10301044 // Collect the source values before broadcasting
@@ -1232,15 +1246,17 @@ class StoreOpFromSplatOrBroadcast final
12321246 return rewriter.notifyMatchFailure (
12331247 op, " only 1-element vectors are supported" );
12341248
1235- Operation *splat = op.getValueToStore ().getDefiningOp ();
1236- if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1237- return rewriter.notifyMatchFailure (op, " neither a splat nor a broadcast" );
1249+ Value toStore = op.getValueToStore ();
1250+ Value source = getBroadcastLikeSource (toStore);
1251+ if (!source)
1252+ return rewriter.notifyMatchFailure (
1253+ op, " value to store is not from a broadcast" );
12381254
12391255 // Checking for single use so we can remove splat.
1256+ Operation *splat = toStore.getDefiningOp ();
12401257 if (!splat->hasOneUse ())
12411258 return rewriter.notifyMatchFailure (op, " expected single op use" );
12421259
1243- Value source = splat->getOperand (0 );
12441260 Value base = op.getBase ();
12451261 ValueRange indices = op.getIndices ();
12461262
@@ -1290,13 +1306,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
12901306 // Add in an offset if requested.
12911307 if (off) {
12921308 Value o = getValueOrCreateCastToIndexLike (rewriter, loc, idxType, *off);
1293- Value ov = rewriter.create <vector::SplatOp >(loc, indices.getType (), o);
1309+ Value ov = rewriter.create <vector::BroadcastOp >(loc, indices.getType (), o);
12941310 indices = rewriter.create <arith::AddIOp>(loc, ov, indices);
12951311 }
12961312 // Construct the vector comparison.
12971313 Value bound = getValueOrCreateCastToIndexLike (rewriter, loc, idxType, b);
12981314 Value bounds =
1299- rewriter.create <vector::SplatOp >(loc, indices.getType (), bound);
1315+ rewriter.create <vector::BroadcastOp >(loc, indices.getType (), bound);
13001316 return rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
13011317 bounds);
13021318}
0 commit comments