@@ -2842,7 +2842,7 @@ LogicalResult BroadcastOp::verify() {
2842
2842
}
2843
2843
2844
2844
// Return the broadcasted dimensions. Including broadcasts in the leading
2845
- // dimensions and broadcasts through unit dimension (i.e. dim-1) .
2845
+ // dimensions and broadcasts through unit dimension.
2846
2846
static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2847
2847
ArrayRef<int64_t > destShape) {
2848
2848
assert (destShape.size () >= srcShape.size ());
@@ -2855,7 +2855,8 @@ static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
2855
2855
}
2856
2856
2857
2857
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2858
- // with broadcast's result type and the broadcasted dimensions are the same.
2858
+ // with broadcast's result type and shape_cast only adds or removes ones in the
2859
+ // leading dimensions.
2859
2860
static LogicalResult foldBroadcastOfShapeCast (BroadcastOp broadcastOp) {
2860
2861
auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
2861
2862
if (!srcShapeCast)
@@ -2868,22 +2869,22 @@ static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
2868
2869
BroadcastableToResult::Success)
2869
2870
return failure ();
2870
2871
2871
- // Given
2872
- // ```
2873
- // %s = shape_cast(%x)
2874
- // %b = broadcast(%s)
2875
- // ```
2876
- // If we want to fold %x into %b, the broadcasted dimensions from %x to
2877
- // %b has to be the same as that of from %s to %b.
2872
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2878
2873
ArrayRef<int64_t > shapecastShape =
2879
2874
srcShapeCast.getResultVectorType ().getShape ();
2880
- ArrayRef< int64_t > srcShape = srcType. getShape ();
2881
- ArrayRef< int64_t > destShape = destType. getShape ();
2882
- BitVector origBroadcastedDims = getBroadcastedDims (shapecastShape, destShape );
2883
- BitVector newBroadcastedDims = getBroadcastedDims ( srcShape, destShape);
2884
- if (newBroadcastedDims != origBroadcastedDims )
2875
+ // Trailing dimensions should be the same if shape_cast only alters the
2876
+ // leading dimensions.
2877
+ unsigned numTrailingDims = std::min (srcShape. size (), shapecastShape. size () );
2878
+ if (! llvm::equal ( srcShape. take_back (numTrailingDims),
2879
+ shapecastShape. take_back (numTrailingDims)) )
2885
2880
return failure ();
2886
2881
2882
+ assert (all_of (srcShape.drop_back (numTrailingDims),
2883
+ [](int64_t E) { return E == 1 ; }) &&
2884
+ all_of (shapecastShape.drop_back (numTrailingDims),
2885
+ [](int64_t E) { return E == 1 ; }) &&
2886
+ " ill-formed shape_cast" );
2887
+
2887
2888
broadcastOp.getSourceMutable ().assign (srcShapeCast.getSource ());
2888
2889
return success ();
2889
2890
}
0 commit comments