Skip to content

Commit 10a914e

Browse files
committed
fixup! Address review comments
1 parent 9ca07a1 commit 10a914e

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
29462946

29472947
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
29482948
PatternRewriter &rewriter) const override {
2949-
if (auto srcShapeCast =
2950-
broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
2951-
VectorType srcType = srcShapeCast.getSourceVectorType();
2952-
VectorType destType = broadcastOp.getResultVectorType();
2953-
if (vector::isBroadcastableTo(srcType, destType) ==
2954-
BroadcastableToResult::Success) {
2955-
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
2956-
srcShapeCast.getSource());
2957-
return success();
2958-
}
2959-
}
2960-
return failure();
2949+
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2950+
if (!srcShapeCast)
2951+
return failure();
2952+
2953+
VectorType srcType = srcShapeCast.getSourceVectorType();
2954+
VectorType destType = broadcastOp.getResultVectorType();
2955+
if (vector::isBroadcastableTo(srcType, destType) !=
2956+
BroadcastableToResult::Success)
2957+
return failure();
2958+
2959+
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
2960+
srcShapeCast.getSource());
2961+
return success();
29612962
}
29622963
};
29632964
} // namespace

0 commit comments

Comments
 (0)