Skip to content

Commit 236c545

Browse files
committed
fixup! Rewrite as a folding pattern
1 parent 0cf5cc1 commit 236c545

File tree

1 file changed

+51
-55
lines changed

1 file changed

+51
-55
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
28412841
llvm_unreachable("unexpected vector.broadcast op error");
28422842
}
28432843

2844+
// Return the broadcasted dimensions. Including broadcasts in the leading
2845+
// dimensions and broadcasts through unit dimension (i.e. dim-1).
2846+
static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
2847+
ArrayRef<int64_t> destShape) {
2848+
assert(destShape.size() >= srcShape.size());
2849+
BitVector broadcastedDims(destShape.size());
2850+
broadcastedDims.set(0, destShape.size() - srcShape.size());
2851+
auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
2852+
for (int64_t dim : unitDims)
2853+
broadcastedDims.set(dim);
2854+
return broadcastedDims;
2855+
}
2856+
2857+
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2858+
// with broadcast's result type and the broadcasted dimensions are the same.
2859+
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
2860+
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2861+
if (!srcShapeCast)
2862+
return failure();
2863+
2864+
VectorType srcType = srcShapeCast.getSourceVectorType();
2865+
VectorType destType = broadcastOp.getResultVectorType();
2866+
// Check type compatibility.
2867+
if (vector::isBroadcastableTo(srcType, destType) !=
2868+
BroadcastableToResult::Success)
2869+
return failure();
2870+
2871+
// Given
2872+
// ```
2873+
// %s = shape_cast(%x)
2874+
// %b = broadcast(%s)
2875+
// ```
2876+
// If we want to fold %x into %b, the broadcasted dimensions from %x to
2877+
// %b has to be the same as that of from %s to %b.
2878+
ArrayRef<int64_t> shapecastShape =
2879+
srcShapeCast.getResultVectorType().getShape();
2880+
ArrayRef<int64_t> srcShape = srcType.getShape();
2881+
ArrayRef<int64_t> destShape = destType.getShape();
2882+
BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape);
2883+
BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
2884+
if (newBroadcastedDims != origBroadcastedDims)
2885+
return failure();
2886+
2887+
broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
2888+
return success();
2889+
}
2890+
28442891
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
28452892
if (getSourceType() == getResultVectorType())
28462893
return getSource();
2894+
if (succeeded(foldBroadcastOfShapeCast(*this)))
2895+
return getResult();
2896+
28472897
if (!adaptor.getSource())
28482898
return {};
28492899
auto vectorType = getResultVectorType();
@@ -2881,67 +2931,13 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
28812931
return success();
28822932
}
28832933
};
2884-
2885-
// Return the broadcasted dimensions. Including broadcasts in the leading
2886-
// dimensions and broadcasts through unit dimension (i.e. dim-1).
2887-
static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
2888-
ArrayRef<int64_t> destShape) {
2889-
assert(destShape.size() >= srcShape.size());
2890-
BitVector broadcastedDims(destShape.size());
2891-
broadcastedDims.set(0, destShape.size() - srcShape.size());
2892-
auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
2893-
for (int64_t dim : unitDims)
2894-
broadcastedDims.set(dim);
2895-
return broadcastedDims;
2896-
}
2897-
2898-
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2899-
// with broadcast's result type and the broadcasted dimensions are the same.
2900-
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
2901-
using OpRewritePattern::OpRewritePattern;
2902-
2903-
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2904-
PatternRewriter &rewriter) const override {
2905-
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2906-
if (!srcShapeCast)
2907-
return failure();
2908-
2909-
VectorType srcType = srcShapeCast.getSourceVectorType();
2910-
VectorType destType = broadcastOp.getResultVectorType();
2911-
// Check type compatibility.
2912-
if (vector::isBroadcastableTo(srcType, destType) !=
2913-
BroadcastableToResult::Success)
2914-
return failure();
2915-
2916-
// Given
2917-
// ```
2918-
// %s = shape_cast(%x)
2919-
// %b = broadcast(%s)
2920-
// ```
2921-
// If we want to fold %x into %b, the broadcasted dimensions from %x to
2922-
// %b has to be the same as that of from %s to %b.
2923-
ArrayRef<int64_t> shapecastShape =
2924-
srcShapeCast.getResultVectorType().getShape();
2925-
ArrayRef<int64_t> srcShape = srcType.getShape();
2926-
ArrayRef<int64_t> destShape = destType.getShape();
2927-
BitVector origBroadcastedDims =
2928-
getBroadcastedDims(shapecastShape, destShape);
2929-
BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
2930-
if (newBroadcastedDims != origBroadcastedDims)
2931-
return failure();
2932-
2933-
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
2934-
srcShapeCast.getSource());
2935-
return success();
2936-
}
2937-
};
29382934
} // namespace
29392935

29402936
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
29412937
MLIRContext *context) {
29422938
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
29432939
// calling `populateCastAwayVectorLeadingOneDimPatterns`
2944-
results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
2940+
results.add<BroadcastFolder>(context);
29452941
}
29462942

29472943
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)