
Commit c9ff839

Authored by Groverkss and kuhar
[mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims (#118208)
This PR fixes how the iteration domain of linalg.generic is collapsed when fusing with tensor.expand_shape. Previously, the output_shape for tensor.expand_shape was inferred, which only works in some special cases. This patch instead sets the bounds of the new collapsed iteration domain explicitly, since they are already known.

Co-authored-by: Jakub Kuderski <[email protected]>
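For illustration, a minimal before/after sketch of the IR this change affects, adapted from the @reshape test in fusion-push-reshape.mlir below (value names are illustrative):

// A generic op consumes the result of an expand_shape with a dynamic
// outermost dimension, and fusion collapses the iteration domain to 2-D.
%expanded = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
    : tensor<?x16xf32> into tensor<?x112x16xf32>

// Before: the output_shape of the re-expanded result was inferred from the
// collapsed op's result, e.g.
//   %dim   = tensor.dim %R, %c0 : tensor<?x16xf32>
//   %val_1 = arith.divsi %dim, %c112 : index
//   tensor.expand_shape %R [[0, 1], [2]] output_shape [%val_1, 112, 16]
//
// After: the already-known bounds of the collapsed iteration domain are
// used directly:
//   %dim = tensor.dim %expanded, %c0 : tensor<?x112x16xf32>
//   tensor.expand_shape %R [[0, 1], [2]] output_shape [%dim, 112, 16]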
1 parent 37d0f20 commit c9ff839

3 files changed: +67 -34 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 19 additions & 12 deletions
@@ -1548,10 +1548,9 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 
 /// Modify the `linalg.index` operations in the original generic op, to its
 /// value in the collapsed operation.
-void generateCollapsedIndexingRegion(Location loc, Block *block,
-                                     const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
-                                     RewriterBase &rewriter) {
+static void generateCollapsedIndexingRegion(
+    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
+    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
 
@@ -1572,10 +1571,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1722,14 +1723,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1747,15 +1747,22 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
+     assert(
+         indexingMap.isProjectedPermutation() &&
+         "Expected indexing map to be a projected permutation for collapsing");
+     SmallVector<OpFoldResult> resultShape =
+         applyPermutationMap(indexingMap, ArrayRef(loopBound));
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(), originalResultType.getElementType());
        result = rewriter.create<memref::ExpandShapeOp>(
-           loc, expandShapeResultType, collapsedOpResult, reassociation);
+           loc, expandShapeResultType, collapsedOpResult, reassociation,
+           resultShape);
      } else {
        result = rewriter.create<tensor::ExpandShapeOp>(
-           loc, originalResultType, collapsedOpResult, reassociation);
+           loc, originalResultType, collapsedOpResult, reassociation,
+           resultShape);
      }
      results.push_back(result);
    } else {

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 45 additions & 18 deletions
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %out = arith.negf %b0 : f32
+      linalg.yield %out : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<?xf32>)
+// CHECK: %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK: return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 // CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
-// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 // CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
 // CHECK: linalg.yield %[[T7]]
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----

mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Lines changed: 3 additions & 4 deletions
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-// CHECK: %[[C112:.*]] = arith.constant 112 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+// CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 // CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
-// CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
-// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[DIM]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
 // CHECK: return %[[RR]] : tensor<?x112x16xf32>
 func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
   %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]