File tree Expand file tree Collapse file tree 1 file changed +13
-12
lines changed
mlir/lib/Dialect/Vector/IR Expand file tree Collapse file tree 1 file changed +13
-12
lines changed Original file line number Diff line number Diff line change @@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
29462946
29472947 LogicalResult matchAndRewrite (BroadcastOp broadcastOp,
29482948 PatternRewriter &rewriter) const override {
2949- if (auto srcShapeCast =
2950- broadcastOp.getSource ().getDefiningOp <ShapeCastOp>()) {
2951- VectorType srcType = srcShapeCast.getSourceVectorType ();
2952- VectorType destType = broadcastOp.getResultVectorType ();
2953- if (vector::isBroadcastableTo (srcType, destType) ==
2954- BroadcastableToResult::Success) {
2955- rewriter.replaceOpWithNewOp <BroadcastOp>(broadcastOp, destType,
2956- srcShapeCast.getSource ());
2957- return success ();
2958- }
2959- }
2960- return failure ();
2949+ auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
2950+ if (!srcShapeCast)
2951+ return failure ();
2952+
2953+ VectorType srcType = srcShapeCast.getSourceVectorType ();
2954+ VectorType destType = broadcastOp.getResultVectorType ();
2955+ if (vector::isBroadcastableTo (srcType, destType) !=
2956+ BroadcastableToResult::Success)
2957+ return failure ();
2958+
2959+ rewriter.replaceOpWithNewOp <BroadcastOp>(broadcastOp, destType,
2960+ srcShapeCast.getSource ());
2961+ return success ();
29612962 }
29622963};
29632964} // namespace
You can’t perform that action at this time.
0 commit comments