Skip to content

Commit 9ca07a1

Browse files
committed
[mlir][vector] Canonicalize broadcast of shape_cast
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type.
1 parent f26c0d0 commit 9ca07a1

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
29382938
return success();
29392939
}
29402940
};
2941+
2942+
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2943+
// with broadcast's result type.
2944+
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
2945+
using OpRewritePattern::OpRewritePattern;
2946+
2947+
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2948+
PatternRewriter &rewriter) const override {
2949+
if (auto srcShapeCast =
2950+
broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
2951+
VectorType srcType = srcShapeCast.getSourceVectorType();
2952+
VectorType destType = broadcastOp.getResultVectorType();
2953+
if (vector::isBroadcastableTo(srcType, destType) ==
2954+
BroadcastableToResult::Success) {
2955+
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
2956+
srcShapeCast.getSource());
2957+
return success();
2958+
}
2959+
}
2960+
return failure();
2961+
}
2962+
};
29412963
} // namespace
29422964

29432965
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
29442966
MLIRContext *context) {
29452967
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
29462968
// calling `populateCastAwayVectorLeadingOneDimPatterns`
2947-
results.add<BroadcastFolder>(context);
2969+
results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
29482970
}
29492971

29502972
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
11681168

11691169
// -----
11701170

1171+
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
1172+
// CHECK-NOT: vector.shape_cast
1173+
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
1174+
func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
1175+
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
1176+
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
1177+
return %1 : vector<32x2xf32>
1178+
}
1179+
1180+
// -----
1181+
1182+
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
1183+
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
1184+
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
1185+
func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
1186+
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
1187+
%1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
1188+
return %1 : vector<2x4x16xf32>
1189+
}
1190+
1191+
// -----
1192+
11711193
// CHECK-LABEL: fold_vector_transfer_masks
11721194
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
11731195
// CHECK: %[[C0:.+]] = arith.constant 0 : index

0 commit comments

Comments
 (0)