Skip to content

Commit aa8154f

Browse files
jtuyls authored and keshavvinayak01 committed
[GPU][DT] dce unused tensor.dim ops in SpecializeExports (iree-org#21624)
Resolves: iree-org#21375. Fixes an issue in e2e llama3 with data-tiling. The SpecializeExports pass creates new tensor.dim operations when retrieving the iteration domain of an operation and leaves them around. When this operates on an encoded tensor, the subsequent MaterializeDeviceEncoding pass will fail on legalizing this tensor.dim operation as it is left around operating on an encoded tensor. We can get around this issue by performing canonicalization (with dce) at the end of SpecializeExports. I also tried adjusting SpecializeExports earlier to clean up the generated tensor.dim operations but that doesn't work as getIterationDomain doesn't give access to all of them. I created an issue and added a TODO to look into that more: iree-org#21623. Signed-off-by: Jorn Tuyls <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent b8e77bb commit aa8154f

File tree

2 files changed

+45
-34
lines changed

2 files changed

+45
-34
lines changed

compiler/src/iree/compiler/Codegen/Common/SpecializeExports.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Arith/Utils/Utils.h"
1919
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2020
#include "mlir/IR/Visitors.h"
21+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2122

2223
#define DEBUG_TYPE "iree-codegen-specialize-exports"
2324

@@ -462,6 +463,17 @@ struct SpecializeExportsPass final
462463
specializeExportedFunctionByRangeAttribute(exportOp, exportedFunc, helper,
463464
ordinalSet);
464465
}
466+
467+
// TODO(#21623): We need DCE after this pass as it can leave around
468+
// `tensor.dim` operations that can mess up the next passes (e.g.
469+
// MaterializeDeviceEncoding). Ideally, we would avoid creating
470+
// those ops altogether if not needed.
471+
MLIRContext *ctx = &getContext();
472+
RewritePatternSet patterns(ctx);
473+
tensor::DimOp::getCanonicalizationPatterns(patterns, ctx);
474+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
475+
signalPassFailure();
476+
}
465477
}
466478
};
467479

compiler/src/iree/compiler/Codegen/Common/test/specialize_exports.mlir

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: iree-opt %s \
2-
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-specialize-exports, cse)))" \
2+
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-specialize-exports)))" \
33
// RUN: --split-input-file | FileCheck %s
44

55
#executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
@@ -40,18 +40,18 @@ hal.executable private @single_specialization_executable {
4040
}
4141
}
4242

43+
// Note the `CHECK-NOT: tensor.dim` which checks that unused `tensor.dim` ops are eliminated.
44+
4345
// CHECK-LABEL: hal.executable private @single_specialization_executable
4446

4547
// CHECK: hal.executable.export public @matmul_transpose_b_Dx1024x4096_f16xf16xf32 ordinal(0)
4648
// CHECK-SAME: condition(%{{.*}}: !hal.device, %[[W:.+]]: index) -> i1
47-
// CHECK-DAG: %[[TRUE:.+]] = arith.constant true
48-
// CHECK: %[[UMIN:.+]] = arith.cmpi ule, %c128, %[[W]]
49-
// CHECK: %[[CMIN:.+]] = arith.andi %[[UMIN]], %[[TRUE]]
50-
// CHECK: %[[UMAX:.+]] = arith.cmpi uge, %c4096, %[[W]]
51-
// CHECK: %[[CMAX:.+]] = arith.andi %[[UMAX]], %[[CMIN]]
52-
// CHECK: %[[UREM:.+]] = arith.remui %[[W]], %c128
53-
// CHECK: %[[UDIV:.+]] = arith.cmpi eq, %[[UREM]], %c0
54-
// CHECK: %[[CDIV:.+]] = arith.andi %[[UDIV]], %[[CMAX]]
49+
// CHECK-DAG: %[[UMAX:.+]] = arith.cmpi uge, %[[W]], %c128
50+
// CHECK-DAG: %[[UMIN:.+]] = arith.cmpi ule, %[[W]], %c4096
51+
// CHECK-DAG: %[[CMIN:.+]] = arith.andi %[[UMIN]], %[[UMAX]]
52+
// CHECK-DAG: %[[UREM:.+]] = arith.remui %[[W]], %c128
53+
// CHECK-DAG: %[[UDIV:.+]] = arith.cmpi eq, %[[UREM]], %c0
54+
// CHECK: %[[CDIV:.+]] = arith.andi %[[UDIV]], %[[CMIN]]
5555
// CHECK: hal.return %[[CDIV]]
5656
// CHECK: fallback(@matmul_transpose_b_Dx1024x4096_f16xf16xf32_0)
5757
// CHECK-SAME: count(%{{[A-Za-z0-9]*}}: !hal.device
@@ -63,8 +63,11 @@ hal.executable private @single_specialization_executable {
6363
// CHECK: builtin.module
6464
// CHECK: func.func @matmul_transpose_b_Dx1024x4096_f16xf16xf32
6565
// CHECK: util.assume.int %{{.*}}<umin = 128, umax = 4096, udiv = 128>
66+
// CHECK-NOT: tensor.dim
6667
// CHECK: func.func @matmul_transpose_b_Dx1024x4096_f16xf16xf32_0
6768
// CHECK: util.assume.int %{{.*}}<umin = 256, umax = 1048320, udiv = 256>
69+
// CHECK-NOT: tensor.dim
70+
6871

6972
// -----
7073

@@ -178,20 +181,18 @@ hal.executable private @multiple_dimension_assume {
178181

179182
// CHECK: hal.executable.export public @matmul_transpose_b_Dx1024x4096_f16xf16xf32 ordinal(0)
180183
// CHECK-SAME: condition(%{{.*}}: !hal.device, %[[W0:[A-Za-z0-9]+]]: index, %[[W1:[A-Za-z0-9]+]]: index) -> i1
181-
// CHECK: %[[TRUE:.+]] = arith.constant true
182-
// CHECK: %[[UMIN:.+]] = arith.cmpi ule, %c128, %[[W0]]
183-
// CHECK: %[[CMIN:.+]] = arith.andi %[[UMIN]], %[[TRUE]]
184-
// CHECK: %[[UMAX:.+]] = arith.cmpi uge, %c4096, %[[W0]]
185-
// CHECK: %[[CMAX:.+]] = arith.andi %[[UMAX]], %[[CMIN]]
186-
// CHECK: %[[UREM:.+]] = arith.remui %[[W0]], %c128
187-
// CHECK: %[[UDIV:.+]] = arith.cmpi eq, %[[UREM]], %c0
188-
// CHECK: %[[CDIV:.+]] = arith.andi %[[UDIV]], %[[CMAX]]
189-
// CHECK: %[[UMIN1:.+]] = arith.cmpi ule, %c128, %[[W1]]
190-
// CHECK: %[[CMIN1:.+]] = arith.andi %[[UMIN1]], %[[CDIV]]
191-
// CHECK: %[[UMAX1:.+]] = arith.cmpi uge, %c4096, %[[W1]]
192-
// CHECK: %[[CMAX1:.+]] = arith.andi %[[UMAX1]], %[[CMIN1]]
193-
// CHECK: %[[UREM1:.+]] = arith.remui %[[W1]], %c128
194-
// CHECK: %[[UDIV1:.+]] = arith.cmpi eq, %[[UREM1]], %c0
184+
// CHECK-DAG: %[[UMAX:.+]] = arith.cmpi uge, %[[W0]], %c128
185+
// CHECK-DAG: %[[UMIN:.+]] = arith.cmpi ule, %[[W0]], %c4096
186+
// CHECK-DAG: %[[CMAX:.+]] = arith.andi %[[UMIN]], %[[UMAX]]
187+
// CHECK-DAG: %[[UREM:.+]] = arith.remui %[[W0]], %c128
188+
// CHECK-DAG: %[[UDIV:.+]] = arith.cmpi eq, %[[UREM]], %c0
189+
// CHECK-DAG: %[[CDIV:.+]] = arith.andi %[[UDIV]], %[[CMAX]]
190+
// CHECK-DAG: %[[UMAX1:.+]] = arith.cmpi uge, %[[W1]], %c128
191+
// CHECK-DAG: %[[CMIN1:.+]] = arith.andi %[[UMAX1]], %[[CDIV]]
192+
// CHECK-DAG: %[[UMIN1:.+]] = arith.cmpi ule, %[[W1]], %c4096
193+
// CHECK-DAG: %[[CMAX1:.+]] = arith.andi %[[UMIN1]], %[[CMIN1]]
194+
// CHECK-DAG: %[[UREM1:.+]] = arith.remui %[[W1]], %c128
195+
// CHECK-DAG: %[[UDIV1:.+]] = arith.cmpi eq, %[[UREM1]], %c0
195196
// CHECK: %[[CDIV1:.+]] = arith.andi %[[UDIV1]], %[[CMAX1]]
196197
// CHECK: hal.return %[[CDIV1]]
197198
// CHECK: fallback(@matmul_transpose_b_Dx1024x4096_f16xf16xf32_0)
@@ -200,17 +201,15 @@ hal.executable private @multiple_dimension_assume {
200201

201202
// CHECK: hal.executable.export public @matmul_transpose_b_Dx1024x4096_f16xf16xf32_0 ordinal(1)
202203
// CHECK-SAME: condition(%{{.*}}: !hal.device, %[[W0:[A-Za-z0-9]+]]: index, %[[W1:[A-Za-z0-9]+]]: index) -> i1
203-
// CHECK: %[[TRUE:.+]] = arith.constant true
204-
// CHECK: %[[UMIN:.+]] = arith.cmpi ule, %c4096, %[[W0]]
205-
// CHECK: %[[CMIN:.+]] = arith.andi %[[UMIN]], %[[TRUE]]
206-
// CHECK: %[[UREM:.+]] = arith.remui %[[W0]], %c256
207-
// CHECK: %[[UDIV:.+]] = arith.cmpi eq, %[[UREM]], %c0
208-
// CHECK: %[[CDIV:.+]] = arith.andi %[[UDIV]], %[[CMIN]]
209-
// CHECK: %[[UMIN1:.+]] = arith.cmpi ule, %c4096, %[[W1]]
210-
// CHECK: %[[CMIN1:.+]] = arith.andi %[[UMIN1]], %[[CDIV]]
211-
// CHECK: %[[UREM1:.+]] = arith.remui %[[W1]], %c256
212-
// CHECK: %[[UDIV1:.+]] = arith.cmpi eq, %[[UREM1]], %c0
213-
// CHECK: %[[CDIV1:.+]] = arith.andi %[[UDIV1]], %[[CMIN1]]
204+
// CHECK-DAG: %[[UMAX:.+]] = arith.cmpi uge, %[[W0]], %c4096
205+
// CHECK-DAG: %[[UREM:.+]] = arith.remui %[[W0]], %c256
206+
// CHECK-DAG: %[[UDIV:.+]] = arith.cmpi eq, %[[UREM]], %c0
207+
// CHECK-DAG: %[[CDIV:.+]] = arith.andi %[[UDIV]], %[[UMAX]]
208+
// CHECK-DAG: %[[UMAX1:.+]] = arith.cmpi uge, %[[W1]], %c4096
209+
// CHECK-DAG: %[[CMIN:.+]] = arith.andi %[[UMAX1]], %[[CDIV]]
210+
// CHECK-DAG: %[[UREM1:.+]] = arith.remui %[[W1]], %c256
211+
// CHECK-DAG: %[[UDIV1:.+]] = arith.cmpi eq, %[[UREM1]], %c0
212+
// CHECK: %[[CDIV1:.+]] = arith.andi %[[UDIV1]], %[[CMIN]]
214213
// CHECK: hal.return %[[CDIV1]]
215214
// CHECK: fallback(@matmul_transpose_b_Dx1024x4096_f16xf16xf32_0_1)
216215
// CHECK-SAME: count(%{{[A-Za-z0-9]*}}: !hal.device

0 commit comments

Comments
 (0)