@@ -2842,7 +2842,7 @@ LogicalResult BroadcastOp::verify() {
28422842}
28432843
28442844// Return the broadcasted dimensions. Including broadcasts in the leading
2845- // dimensions and broadcasts through unit dimension (i.e. dim-1) .
2845+ // dimensions and broadcasts through unit dimension.
28462846static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
28472847 ArrayRef<int64_t > destShape) {
28482848 assert (destShape.size () >= srcShape.size ());
@@ -2855,7 +2855,8 @@ static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
28552855}
28562856
28572857// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2858- // with broadcast's result type and the broadcasted dimensions are the same.
2858+ // with broadcast's result type and shape_cast only adds or removes ones in the
2859+ // leading dimensions.
28592860static LogicalResult foldBroadcastOfShapeCast (BroadcastOp broadcastOp) {
28602861 auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
28612862 if (!srcShapeCast)
@@ -2868,22 +2869,22 @@ static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
28682869 BroadcastableToResult::Success)
28692870 return failure ();
28702871
2871- // Given
2872- // ```
2873- // %s = shape_cast(%x)
2874- // %b = broadcast(%s)
2875- // ```
2876- // If we want to fold %x into %b, the broadcasted dimensions from %x to
2877- // %b has to be the same as that of from %s to %b.
2872+ ArrayRef<int64_t > srcShape = srcType.getShape ();
28782873 ArrayRef<int64_t > shapecastShape =
28792874 srcShapeCast.getResultVectorType ().getShape ();
2880- ArrayRef< int64_t > srcShape = srcType. getShape ();
2881- ArrayRef< int64_t > destShape = destType. getShape ();
2882- BitVector origBroadcastedDims = getBroadcastedDims (shapecastShape, destShape );
2883- BitVector newBroadcastedDims = getBroadcastedDims ( srcShape, destShape);
2884- if (newBroadcastedDims != origBroadcastedDims )
2875+ // Trailing dimensions should be the same if shape_cast only alters the
2876+ // leading dimensions.
2877+ unsigned numTrailingDims = std::min (srcShape. size (), shapecastShape. size () );
2878+ if (! llvm::equal ( srcShape. take_back (numTrailingDims),
2879+ shapecastShape. take_back (numTrailingDims)) )
28852880 return failure ();
28862881
2882+ assert (all_of (srcShape.drop_back (numTrailingDims),
2883+ [](int64_t E) { return E == 1 ; }) &&
2884+ all_of (shapecastShape.drop_back (numTrailingDims),
2885+ [](int64_t E) { return E == 1 ; }) &&
2886+ " ill-formed shape_cast" );
2887+
28872888 broadcastOp.getSourceMutable ().assign (srcShapeCast.getSource ());
28882889 return success ();
28892890}
0 commit comments