Skip to content

Commit 42d8959

Browse files
committed
Move dropGivenUnitDims to Tensor Utils
1 parent cf20e80 commit 42d8959

File tree

3 files changed: +40 −34 lines changed

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ FailureOr<RankedTensorType>
4343
computeTransposedType(RankedTensorType rankedTensorType,
4444
ArrayRef<int64_t> transposeVector);
4545

46+
/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
47+
/// `from`.
48+
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
49+
const llvm::SmallBitVector &dropDims);
50+
4651
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
4752
/// source tensor or inserts the source tensor into a destination tensor with
4853
/// the same shape.

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -236,39 +236,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
236236
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
237237
}
238238

239-
// NOTE(review): this static helper is deleted by this commit; an identical
// definition now lives in mlir/lib/Dialect/Tensor/Utils/Utils.cpp as
// mlir::tensor::dropGivenUnitDims (see the third hunk of this diff).
/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
240-
/// `from`.
241-
static tensor::CollapseShapeOp
242-
dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
243-
const llvm::SmallBitVector &dropDims) {
244-
auto fromType = cast<ShapedType>(from.getType());
245-
int64_t rank = fromType.getRank();
246-
// Precondition: one bit per source dimension.
assert(rank == static_cast<int64_t>(dropDims.size()) &&
247-
"dropDims dimension does not match from tensor rank");
248-
// Precondition: only statically-known unit dimensions may be dropped.
assert(llvm::all_of(
249-
dropDims.set_bits(),
250-
[&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
251-
"Dropping non unit dimension");
252-
// Computed reassociation map for the corresponding tensor.collapse_shape.
253-
SmallVector<ReassociationIndices, 2> reassocMaps;
254-
// Current reassociation group to add dropped dimension to.
255-
// NOTE(review): the comment above is orphaned — no declaration follows it.
256-
int64_t nextDimToGroup = 0;
257-
// keptDims is the complement of dropDims: each set bit marks a retained
// dimension, and each retained dimension anchors one reassociation group.
llvm::SmallBitVector keptDims(dropDims);
258-
keptDims.flip();
259-
// The group anchored on the last retained dimension must also absorb any
// trailing dropped dimensions, so remember which retained bit is last.
int64_t lastSetBit = keptDims.find_last();
260-
for (int64_t setBit : keptDims.set_bits()) {
261-
// Group consecutive dropped dimension with the next non-dropped dimension.
262-
// If this is the last set dimension, also group all subsequent dropped
263-
// dimension, if any.
264-
int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
265-
auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
266-
reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
267-
nextDimToGroup = setBit + 1;
268-
}
269-
// If every dimension is dropped the loop never runs and reassocMaps stays
// empty — presumably producing a rank-0 collapse; verify against the
// tensor.collapse_shape op semantics.
return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
270-
}
271-
272239
FailureOr<FusionInfo>
273240
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
274241
OpOperand &consumerOpOperand) {
@@ -312,7 +279,8 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
312279
// Rank-reduction occurred as part of the extract_slice.
313280
if (cast<ShapedType>(consumerType).getRank() !=
314281
cast<ShapedType>(def.getType()).getRank())
315-
def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
282+
def =
283+
tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
316284
// Canonicalizations are not guaranteed to have happened before constructing
317285
// `fusedProducer`. In the tensor case this can result in temporary type
318286
// mismatches. Insert a `tensor.cast` op to propagate the transformation

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,39 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
9494
return transposedTensorType;
9595
}
9696

97+
/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
98+
/// `from`.
99+
CollapseShapeOp
100+
mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
101+
const llvm::SmallBitVector &dropDims) {
102+
auto fromType = cast<ShapedType>(from.getType());
103+
int64_t rank = fromType.getRank();
104+
assert(rank == static_cast<int64_t>(dropDims.size()) &&
105+
"dropDims dimension does not match from tensor rank");
106+
assert(llvm::all_of(
107+
dropDims.set_bits(),
108+
[&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
109+
"Dropping non unit dimension");
110+
// Computed reassociation map for the corresponding tensor.collapse_shape.
111+
SmallVector<ReassociationIndices, 2> reassocMaps;
112+
// Current reassociation group to add dropped dimension to.
113+
114+
int64_t nextDimToGroup = 0;
115+
llvm::SmallBitVector keptDims(dropDims);
116+
keptDims.flip();
117+
int64_t lastSetBit = keptDims.find_last();
118+
for (int64_t setBit : keptDims.set_bits()) {
119+
// Group consecutive dropped dimension with the next non-dropped dimension.
120+
// If this is the last set dimension, also group all subsequent dropped
121+
// dimension, if any.
122+
int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
123+
auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
124+
reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
125+
nextDimToGroup = setBit + 1;
126+
}
127+
return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
128+
}
129+
97130
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
98131
llvm::SmallBitVector droppedDims = op.getDroppedDims();
99132
int64_t srcDim = 0;

0 commit comments

Comments
 (0)