Skip to content

Commit 892fdc9

Browse files
author
Mahesh Ravishankar
committed
[mlir][Linalg] Generalize the logic to compute reassociation maps
while folding tensor_reshape op. While folding reshapes that introduce unit extent dims, the logic to compute the reassociation maps can be generalized to handle some corner cases, for example, when the folded shape still has unit-extent dims but corresponds to folded unit extent dims of the expanded shape. Differential Revision: https://reviews.llvm.org/D88521
1 parent 3a7487f commit 892fdc9

File tree

2 files changed

+58
-45
lines changed

2 files changed

+58
-45
lines changed

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
403403
srcType.getRank() < dstType.getRank() ||
404404
parentSrcType.getRank() == dstType.getRank())
405405
return failure();
406+
406407
// Check if the result tensor_reshape after folding the reshapeOp and
407408
// parentReshapeOp are combined.
408409
// If the final tensor_reshape is folding, the parentReshapeOp is
409410
// introducing unit-dims, and the reshapeOp does an actual reshape.
410-
// If the final tensor_reshape op is expanding, the reshapeOp is introducing
411-
// unit-dims, and the parentReshapeOp does an actual reshape.
411+
// If the final tensor_reshape op is expanding, the reshapeOp is
412+
// introducing unit-dims, and the parentReshapeOp does an actual reshape.
412413
bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
413-
auto reassociationMaps = isFoldingPattern
414-
? reshapeOp.getReassociationMaps()
415-
: parentReshapeOp.getReassociationMaps();
416-
DenseSet<unsigned> conservedDimensions;
417-
for (auto &map : reassociationMaps) {
418-
if (map.getNumResults() == 1) {
419-
conservedDimensions.insert(
420-
map.getResult(0).cast<AffineDimExpr>().getPosition());
421-
}
422-
}
423-
424-
// Find positions at which the unit-dims exist.
425-
int64_t nonUnitDimPos = 0;
426-
DenseMap<unsigned, unsigned> nonUnitSrcDims;
427-
ArrayRef<int64_t> nonUnitShape =
414+
ArrayRef<int64_t> expandedShape =
428415
isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
429-
for (auto shape : enumerate(srcType.getShape())) {
430-
// Case 1 : It is a conserved dimension.
431-
if (conservedDimensions.count(shape.index())) {
432-
nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
433-
continue;
416+
ArrayRef<int64_t> foldedShape =
417+
isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
418+
419+
unsigned expandedDim = 0, foldedDim = 0;
420+
SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
421+
foldedShape.size());
422+
while (expandedDim < expandedShape.size() &&
423+
foldedDim < foldedShape.size()) {
424+
int64_t dstSize = foldedShape[foldedDim];
425+
int64_t srcSize = expandedShape[expandedDim];
426+
while (srcSize < dstSize && expandedDim < expandedShape.size()) {
427+
reassociationExprs[foldedDim].push_back(
428+
rewriter.getAffineDimExpr(expandedDim++));
429+
srcSize *= expandedShape[expandedDim];
434430
}
435-
// Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
436-
if (shape.value() == 1)
437-
continue;
438-
// Case 3 : Dimensions match, treat it as a non-unit src dim.
439-
if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
440-
nonUnitShape[nonUnitDimPos] == shape.value()) {
441-
nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
442-
continue;
431+
if (srcSize == dstSize) {
432+
reassociationExprs[foldedDim].push_back(
433+
rewriter.getAffineDimExpr(expandedDim++));
434+
// If the next dim in foldedShape is not 1, treat subsequent dims in
435+
// expandedShape which are 1 to be collapsed.
436+
if (foldedDim == foldedShape.size() - 1 ||
437+
foldedShape[foldedDim + 1] != 1) {
438+
while (expandedDim < expandedShape.size() &&
439+
expandedShape[expandedDim] == 1) {
440+
reassociationExprs[foldedDim].push_back(
441+
rewriter.getAffineDimExpr(expandedDim++));
442+
}
443+
}
444+
} else {
445+
return failure();
443446
}
444-
return failure();
447+
foldedDim++;
445448
}
449+
if (expandedDim != expandedShape.size())
450+
return failure();
446451

447-
// Compute reassociation maps for the final operation. Use the reassociation
448-
// maps that is actually doing a reshape (and not just introducing
449-
// unit-dims). From these maps, prune the unit-extent dimensions.
450-
for (AffineMap &map : reassociationMaps) {
451-
SmallVector<AffineExpr, 4> exprs;
452-
exprs.reserve(nonUnitSrcDims.size());
453-
for (auto result : map.getResults()) {
454-
unsigned dim = result.cast<AffineDimExpr>().getPosition();
455-
if (nonUnitSrcDims.count(dim))
456-
exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
457-
}
458-
map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
459-
rewriter.getContext());
460-
}
452+
SmallVector<AffineMap, 4> reassociationMaps =
453+
llvm::to_vector<4>(llvm::map_range(
454+
reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
455+
return AffineMap::get(expandedShape.size(), 0, exprs,
456+
rewriter.getContext());
457+
}));
461458
rewriter.replaceOpWithNewOp<TensorReshapeOp>(
462459
reshapeOp, dstType, parentReshapeOp.src(),
463460
rewriter.getAffineMapArrayAttr(reassociationMaps));

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,19 @@ func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
240240
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
241241
return %1 : tensor<4x512x1x512x4xf32>
242242
}
243+
244+
// -----
245+
246+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
247+
// CHECK: func @fold_reshape
248+
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
249+
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
250+
func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
251+
{
252+
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
253+
%1 = linalg.tensor_reshape %0
254+
[affine_map<(d0, d1, d2) -> (d0)>,
255+
affine_map<(d0, d1, d2) -> (d1, d2)>
256+
] : tensor<2x1x1xf32> into tensor<2x1xf32>
257+
return %1 : tensor<2x1xf32>
258+
}

0 commit comments

Comments (0)