@@ -236,39 +236,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
-/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
-static tensor::CollapseShapeOp
-dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
-                  const llvm::SmallBitVector &dropDims) {
-  auto fromType = cast<ShapedType>(from.getType());
-  int64_t rank = fromType.getRank();
-  assert(rank == static_cast<int64_t>(dropDims.size()) &&
-         "dropDims dimension does not match from tensor rank");
-  assert(llvm::all_of(
-             dropDims.set_bits(),
-             [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
-         "Dropping non unit dimension");
-  // Computed reassociation map for the corresponding tensor.collapse_shape.
-  SmallVector<ReassociationIndices, 2> reassocMaps;
-  // Current reassociation group to add dropped dimension to.
-
-  int64_t nextDimToGroup = 0;
-  llvm::SmallBitVector keptDims(dropDims);
-  keptDims.flip();
-  int64_t lastSetBit = keptDims.find_last();
-  for (int64_t setBit : keptDims.set_bits()) {
-    // Group consecutive dropped dimension with the next non-dropped dimension.
-    // If this is the last set dimension, also group all subsequent dropped
-    // dimension, if any.
-    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
-    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
-    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
-    nextDimToGroup = setBit + 1;
-  }
-  return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
-}
-
 FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
@@ -312,7 +279,8 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   // Rank-reduction occurred as part of the extract_slice.
   if (cast<ShapedType>(consumerType).getRank() !=
       cast<ShapedType>(def.getType()).getRank())
-    def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+    def =
+        tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
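
For readers unfamiliar with the helper being moved into the tensor dialect utilities: given a bit vector marking unit dimensions to drop, it builds reassociation groups for a tensor.collapse_shape by folding each run of dropped dimensions into the next kept dimension, with the last kept dimension also absorbing any trailing dropped ones. Below is a minimal standalone sketch of that grouping logic, assuming plain standard-library containers in place of MLIR's SmallBitVector and ReassociationIndices; `computeReassociation` is an illustrative name, not the MLIR API.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the reassociation-group computation: each entry of the result
// lists the source dimensions that collapse into one target dimension.
static std::vector<std::vector<int64_t>>
computeReassociation(const std::vector<bool> &dropDims) {
  int64_t rank = static_cast<int64_t>(dropDims.size());
  // Find the last kept (non-dropped) dimension.
  int64_t lastKept = -1;
  for (int64_t d = 0; d < rank; ++d)
    if (!dropDims[d])
      lastKept = d;
  assert(lastKept >= 0 && "cannot drop every dimension");

  std::vector<std::vector<int64_t>> groups;
  int64_t nextDimToGroup = 0;
  for (int64_t d = 0; d < rank; ++d) {
    if (dropDims[d])
      continue;
    // Group the preceding dropped dims with this kept dim; the last kept
    // dim also absorbs any trailing dropped dims.
    int64_t upTo = (d == lastKept) ? rank - 1 : d;
    std::vector<int64_t> group;
    for (int64_t i = nextDimToGroup; i <= upTo; ++i)
      group.push_back(i);
    groups.push_back(group);
    nextDimToGroup = upTo + 1;
  }
  return groups;
}

int main() {
  // Shape [1, 4, 1, 8, 1], dropping dims {0, 2, 4} -> groups [[0, 1], [2, 3, 4]],
  // i.e. tensor<1x4x1x8x1> collapses to tensor<4x8>.
  auto groups = computeReassociation({true, false, true, false, true});
  for (const auto &g : groups) {
    for (int64_t d : g)
      std::cout << d << ' ';
    std::cout << '\n';
  }
}
```

Running this prints the two groups `0 1` and `2 3 4`, matching the reassociation the removed helper would have produced for a tensor of shape 1x4x1x8x1 with unit dimensions 0, 2, and 4 dropped.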