Skip to content

Commit 76eeb36

Browse files
committed
[mlir][tensor][memref] Enhance collapse(expand(src)) canonicalization pattern.
The expand_shape op takes dynamic output shape values, and we need to take them into account when composing the ops. Otherwise, the pattern fails to create the new expand_shape op. Signed-off-by: hanhanW <[email protected]>
1 parent 772eb07 commit 76eeb36

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
1515
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
1616

17+
#include "mlir/Dialect/Arith/IR/Arith.h"
1718
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1819
#include "mlir/IR/OpImplementation.h"
1920
#include "mlir/IR/PatternMatch.h"
@@ -305,8 +306,42 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
305306
rewriter.replaceOpWithNewOp<CollapseOpTy>(
306307
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
307308
} else if (srcRank < resultRank) {
309+
// Compute the dynamic output shape for the new expand_shape op.
310+
Location loc = collapseOp.getLoc();
311+
SmallVector<OpFoldResult> origOutputShape =
312+
expandOp.getMixedOutputShape();
313+
SmallVector<OpFoldResult> newOutputShape;
314+
for (auto indices : collapseOp.getReassociationIndices()) {
315+
int64_t numStaticElems = 1;
316+
SmallVector<Value> dynamicSizes;
317+
for (auto idx : indices) {
318+
OpFoldResult size = origOutputShape[idx];
319+
if (auto maybeCst = getConstantIntValue(size)) {
320+
numStaticElems *= maybeCst.value();
321+
continue;
322+
}
323+
dynamicSizes.push_back(cast<Value>(size));
324+
}
325+
if (dynamicSizes.empty()) {
326+
newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
327+
continue;
328+
}
329+
330+
// There is at least one dynamic size, so we can initialize `result` to
331+
// the first dynamic size.
332+
Value result = dynamicSizes[0];
333+
for (auto v : llvm::drop_begin(dynamicSizes))
334+
result = rewriter.create<arith::MulIOp>(loc, result, v);
335+
if (numStaticElems != 1) {
336+
result = rewriter.create<arith::MulIOp>(
337+
loc, result,
338+
rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
339+
}
340+
newOutputShape.push_back(result);
341+
}
308342
rewriter.replaceOpWithNewOp<ExpandOpTy>(
309-
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
343+
collapseOp, resultType, expandOp.getSrc(), composedReassociation,
344+
newOutputShape);
310345
} else {
311346
// Collapses/expansions that do not change the rank are not allowed. Use
312347
// a cast instead.

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
466466

467467
// -----
468468

469+
func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
470+
%expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
471+
%collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
472+
return %collapsed : memref<8x?x?xf16>
473+
}
474+
// CHECK: func @compose_collapse_of_expand_partially_dynamic
475+
// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
476+
// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
477+
// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
478+
// CHECK-DAG: %[[C32:.+]] = arith.constant 32
479+
// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
480+
// CHECK: %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
481+
// CHECK-SAME: [0, 1, 2]
482+
// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
483+
// CHECK: return %[[RESULT]]
484+
485+
// -----
486+
469487
func.func @do_not_compose_collapse_of_expand_non_identity_layout(
470488
%arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
471489
-> memref<?xf32, strided<[?], offset: 0>> {

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
12431243

12441244
// -----
12451245

1246+
func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
1247+
%expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
1248+
%collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
1249+
return %collapsed : tensor<8x?x?xf16>
1250+
}
1251+
// CHECK: func @compose_collapse_of_expand_partially_dynamic
1252+
// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
1253+
// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
1254+
// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
1255+
// CHECK-DAG: %[[C32:.+]] = arith.constant 32
1256+
// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
1257+
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
1258+
// CHECK-SAME: [0, 1, 2]
1259+
// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
1260+
// CHECK: return %[[RESULT]]
1261+
1262+
// -----
1263+
12461264
func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
12471265
-> tensor<1x1x1x1xf32> {
12481266
%0 = tensor.collapse_shape %arg0 []

0 commit comments

Comments
 (0)