Skip to content

Commit 5a97523

Browse files
authored
[GPU] Disable insert/extract slice lowering from pack/unpack by default (iree-org#19590)
This PR is a follow-up to llvm/llvm-project#117340. It disables `lowerPadLikeWithInsertSlice` and `lowerUnpadLikeWithExtractSlice` by default, so that `tensor.insert_slice` or `tensor.extract_slice` ops won't appear when the outer (high) dimensions are unit dimensions. --------- Signed-off-by: jerryyin <[email protected]>
1 parent 1e935c4 commit 5a97523

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
5656
if (controlFn && failed(controlFn.value()(op))) {
5757
return failure();
5858
}
59-
FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, op);
59+
FailureOr<linalg::LowerPackResult> res =
60+
linalg::lowerPack(rewriter, op, /*lowerPadLikeWithInsertSlice=*/false);
6061
if (failed(res)) {
6162
return rewriter.notifyMatchFailure(
6263
op, "cannot lower to pad + expand + transpose");
@@ -83,8 +84,8 @@ struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
8384
if (controlFn && failed(controlFn.value()(op))) {
8485
return failure();
8586
}
86-
FailureOr<linalg::LowerUnPackOpResult> res =
87-
linalg::lowerUnPack(rewriter, op);
87+
FailureOr<linalg::LowerUnPackOpResult> res = linalg::lowerUnPack(
88+
rewriter, op, /*lowerUnpadLikeWithExtractSlice=*/false);
8889
if (failed(res)) {
8990
return rewriter.notifyMatchFailure(
9091
op, "cannot lower to empty + transpose + reshape + extract_slice");

compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
3030
// CHECK-ALL-SAME: %[[PAD_VAL:[A-Za-z0-9]+]]:
3131
// CHECK-ALL: %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
3232
// CHECK-ALL: tensor.yield %[[PAD_VAL]]
33-
// CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
34-
// CHECK-ALL: return %[[INSERT]]
33+
34+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
35+
// CHECK: return %[[INSERT]]
36+
37+
// CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] output_shape [1, 8, 1, 2] : tensor<8x2xf32> into tensor<1x8x1x2xf32>
38+
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x8x1x2xf32>) outs(%[[OUT]] : tensor<1x1x8x2xf32>) permutation = [0, 2, 1, 3]
39+
// CHECK-RESHAPE: return %[[TRANS]]
3540

3641
// -----
3742

@@ -42,8 +47,13 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
4247
// CHECK-ALL-LABEL: func.func @simple_NC_to_CNnc
4348
// CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]:
4449
// CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]:
45-
// CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
46-
// CHECK-ALL: return %[[INSERT]]
50+
51+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
52+
// CHECK: return %[[INSERT]]
53+
54+
// CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0, 1], [2, 3]] output_shape [1, 32, 1, 8] : tensor<32x8xf32> into tensor<1x32x1x8xf32>
55+
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x32x1x8xf32>) outs(%[[OUT]] : tensor<1x1x32x8xf32>) permutation = [2, 0, 1, 3]
56+
// CHECK-RESHAPE: return %[[TRANS]]
4757

4858
// -----
4959

@@ -132,8 +142,11 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
132142
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
133143
// CHECK: %[[RES:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
134144

135-
// CHECK-RESHAPE: %[[RES:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1]
136-
145+
// CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x2xf32>
146+
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x8x2xf32>) outs(%[[EMPTY]] : tensor<1x8x1x2xf32>) permutation = [0, 2, 1, 3]
147+
// CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape
148+
// CHECK-RESHAPE: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]]
149+
// CHECK-RESHAPE: %[[RES:.+]] = linalg.copy ins(%[[SLICE]] : tensor<5x1xf32>) outs(%[[OUT]] : tensor<5x1xf32>) -> tensor<5x1xf32>
137150
// CHECK-ALL: return %[[RES:.+]]
138151

139152
// -----
@@ -145,8 +158,13 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
145158
// CHECK-ALL-LABEL: func.func @simple_CNnc_to_NC
146159
// CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]:
147160
// CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]:
148-
// CHECK-ALL: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
149-
// CHECK-ALL: return %[[TILE]]
161+
// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
162+
163+
// CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x1x8xf32>
164+
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x32x8xf32>) outs(%[[EMPTY]] : tensor<1x32x1x8xf32>) permutation = [1, 2, 0, 3]
165+
// CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape
166+
// CHECK-RESHAPE: %[[RESULT:.+]] = linalg.copy ins(%[[COLLAPSE]] : tensor<32x8xf32>) outs(%[[OUT]] : tensor<32x8xf32>) -> tensor<32x8xf32>
167+
// CHECK-ALL: return %[[RESULT]]
150168

151169
// -----
152170

0 commit comments

Comments (0)