Skip to content

Commit e370b81

Browse files
committed
fixup! Simplify the algorithm for the legality check
1 parent 236c545 commit e370b81

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,7 +2842,7 @@ LogicalResult BroadcastOp::verify() {
28422842
}
28432843

28442844
// Return the broadcasted dimensions. Including broadcasts in the leading
2845-
// dimensions and broadcasts through unit dimension (i.e. dim-1).
2845+
// dimensions and broadcasts through unit dimension.
28462846
static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
28472847
ArrayRef<int64_t> destShape) {
28482848
assert(destShape.size() >= srcShape.size());
@@ -2855,7 +2855,8 @@ static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
28552855
}
28562856

28572857
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2858-
// with broadcast's result type and the broadcasted dimensions are the same.
2858+
// with broadcast's result type and shape_cast only adds or removes ones in the
2859+
// leading dimensions.
28592860
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
28602861
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
28612862
if (!srcShapeCast)
@@ -2868,22 +2869,22 @@ static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
28682869
BroadcastableToResult::Success)
28692870
return failure();
28702871

2871-
// Given
2872-
// ```
2873-
// %s = shape_cast(%x)
2874-
// %b = broadcast(%s)
2875-
// ```
2876-
// If we want to fold %x into %b, the broadcasted dimensions from %x to
2877-
// %b has to be the same as that of from %s to %b.
2872+
ArrayRef<int64_t> srcShape = srcType.getShape();
28782873
ArrayRef<int64_t> shapecastShape =
28792874
srcShapeCast.getResultVectorType().getShape();
2880-
ArrayRef<int64_t> srcShape = srcType.getShape();
2881-
ArrayRef<int64_t> destShape = destType.getShape();
2882-
BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape);
2883-
BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
2884-
if (newBroadcastedDims != origBroadcastedDims)
2875+
// Trailing dimensions should be the same if shape_cast only alters the
2876+
// leading dimensions.
2877+
unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
2878+
if (!llvm::equal(srcShape.take_back(numTrailingDims),
2879+
shapecastShape.take_back(numTrailingDims)))
28852880
return failure();
28862881

2882+
assert(all_of(srcShape.drop_back(numTrailingDims),
2883+
[](int64_t E) { return E == 1; }) &&
2884+
all_of(shapecastShape.drop_back(numTrailingDims),
2885+
[](int64_t E) { return E == 1; }) &&
2886+
"ill-formed shape_cast");
2887+
28872888
broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
28882889
return success();
28892890
}

0 commit comments

Comments
 (0)