Skip to content

Commit fcf79e5

Browse files
authored
[MLIR] Improve in-place folding to iterate until fixed-point (#160615)
When executed in the context of canonicalization, the folders are invoked in a fixed-point iterative process. However in the context of an API like `createOrFold()` or in DialectConversion for example, we expect a "one-shot" call to fold to be as "folded" as possible. However, even when folders themselves are indempotent, folders on a given operation interact with each other. For example: ``` // X = 0 + Y %X = arith.addi %c_0, %Y : i32 ``` should fold to %Y, but the process actually involves first the folder provided by the IsCommutative trait to move the constant to the right. However this happens after attempting to fold the operation and the operation folder isn't attempt again after applying the trait folder. This commit makes sure we iterate until fixed point on folder applications. Fixes #159844
1 parent e276000 commit fcf79e5

File tree

9 files changed

+84
-42
lines changed

9 files changed

+84
-42
lines changed

mlir/include/mlir/Transforms/FoldUtils.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ class OperationFolder {
4040
/// deduplicated constants. If successful, replaces `op`'s uses with
4141
/// folded results, and returns success. If the op was completely folded it is
4242
/// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
43-
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr);
43+
/// On success, when the fold is in-place, the folder is re-invoked until
44+
/// `maxIterations` is reached (default INT_MAX).
45+
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr,
46+
int maxIterations = INT_MAX);
4447

4548
/// Tries to fold a pre-existing constant operation. `constValue` represents
4649
/// the value of the constant, and can be optionally passed if the value is
@@ -82,7 +85,10 @@ class OperationFolder {
8285

8386
/// Tries to perform folding on the given `op`. If successful, populates
8487
/// `results` with the results of the folding.
85-
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results);
88+
/// On success, when the fold is in-place, the folder is re-invoked until
89+
/// `maxIterations` is reached (default INT_MAX).
90+
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results,
91+
int maxIterations = INT_MAX);
8692

8793
/// Try to process a set of fold results. Populates `results` on success,
8894
/// otherwise leaves it unchanged.

mlir/lib/IR/Builders.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/IRMapping.h"
1515
#include "mlir/IR/Matchers.h"
1616
#include "llvm/ADT/SmallVectorExtras.h"
17+
#include "llvm/Support/DebugLog.h"
1718

1819
using namespace mlir;
1920

@@ -486,9 +487,25 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
486487

487488
// Try to fold the operation.
488489
SmallVector<OpFoldResult, 4> foldResults;
490+
LDBG() << "Trying to fold: "
491+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
492+
if (op->getName().getStringRef() == "vector.extract") {
493+
Operation *parent = op->getParentOp();
494+
while (parent && parent->getName().getStringRef() != "spirv.func")
495+
parent = parent->getParentOp();
496+
if (parent)
497+
parent->dump();
498+
}
489499
if (failed(op->fold(foldResults)))
490500
return cleanupFailure();
491501

502+
int count = 0;
503+
do {
504+
LDBG() << "Folded in place #" << count
505+
<< " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
506+
count++;
507+
} while (foldResults.empty() && succeeded(op->fold(foldResults)));
508+
492509
// An in-place fold does not require generation of any constants.
493510
if (foldResults.empty())
494511
return success();

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/Builders.h"
1717
#include "mlir/IR/Matchers.h"
1818
#include "mlir/IR/Operation.h"
19+
#include "llvm/Support/DebugLog.h"
1920

2021
using namespace mlir;
2122

@@ -67,7 +68,8 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
6768
// OperationFolder
6869
//===----------------------------------------------------------------------===//
6970

70-
LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
71+
LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate,
72+
int maxIterations) {
7173
if (inPlaceUpdate)
7274
*inPlaceUpdate = false;
7375

@@ -86,7 +88,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
8688

8789
// Try to fold the operation.
8890
SmallVector<Value, 8> results;
89-
if (failed(tryToFold(op, results)))
91+
if (failed(tryToFold(op, results, maxIterations)))
9092
return failure();
9193

9294
// Check to see if the operation was just updated in place.
@@ -224,10 +226,19 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
224226
/// Tries to perform folding on the given `op`. If successful, populates
225227
/// `results` with the results of the folding.
226228
LogicalResult OperationFolder::tryToFold(Operation *op,
227-
SmallVectorImpl<Value> &results) {
229+
SmallVectorImpl<Value> &results,
230+
int maxIterations) {
228231
SmallVector<OpFoldResult, 8> foldResults;
229-
if (failed(op->fold(foldResults)) ||
230-
failed(processFoldResults(op, results, foldResults)))
232+
if (failed(op->fold(foldResults)))
233+
return failure();
234+
int count = 1;
235+
do {
236+
LDBG() << "Folded in place #" << count
237+
<< " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
238+
} while (count++ < maxIterations && foldResults.empty() &&
239+
succeeded(op->fold(foldResults)));
240+
241+
if (failed(processFoldResults(op, results, foldResults)))
231242
return failure();
232243
return success();
233244
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Test with the default (one application of the folder) and then with 2 iterations.
2+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-single-fold))" | FileCheck %s --check-prefixes=CHECK,CHECK-ONE
3+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-single-fold{max-iterations=2}))" | FileCheck %s --check-prefixes=CHECK,CHECK-TWO
4+
5+
6+
// Fully folding this op requires first moving the constant to the right
7+
// before invoking the op-specific folder.
8+
// With one iteration, we just push the constant to the right.
9+
// With a second iteration, we actually fold the "add" (x+0->x)
10+
// CHECK: func @recurse_fold_traits(%[[ARG0:.*]]: i32)
11+
func.func @recurse_fold_traits(%arg0 : i32) -> i32 {
12+
%cst0 = arith.constant 0 : i32
13+
// CHECK-ONE: %[[ADD:.*]] = arith.addi %[[ARG0]],
14+
%res = arith.addi %cst0, %arg0 : i32
15+
// CHECK-ONE: return %[[ADD]] : i32
16+
// CHECK-TWO: return %[[ARG0]] : i32
17+
return %res : i32
18+
}

mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ gpu.module @test {
77
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
88
//CHECK: [[c32:%.+]] = arith.constant 32 : index
99
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
10-
//CHECK: [[c0:%.+]] = arith.constant 0 : index
11-
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
1210
//CHECK: [[c128:%.+]] = arith.constant 128 : index
13-
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
11+
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
1412
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
1513
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
1614
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
@@ -23,10 +21,8 @@ gpu.module @test {
2321
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
2422
//CHECK: [[c32:%.+]] = arith.constant 32 : index
2523
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
26-
//CHECK: [[c0:%.+]] = arith.constant 0 : index
27-
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
2824
//CHECK: [[c128:%.+]] = arith.constant 128 : index
29-
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
25+
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
3026
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
3127
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
3228
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,10 @@ gpu.module @test_round_robin_assignment {
2727
//CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
2828
//CHECK: [[C0:%.+]] = arith.constant 0 : index
2929
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
30-
//CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
31-
//CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
3230
//CHECK: [[C128:%.+]] = arith.constant 128 : index
33-
//CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
31+
//CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
3432
//CHECK: [[C64_2:%.+]] = arith.constant 64 : index
35-
//CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
33+
//CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
3634
//CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
3735
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
3836
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,10 @@ gpu.module @test_distribution {
325325
//CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
326326
//CHECK: [[c0:%.+]] = arith.constant 0 : index
327327
//CHECK: [[c0_1:%.+]] = arith.constant 0 : index
328-
//CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
329-
//CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
330328
//CHECK: [[c64:%.+]] = arith.constant 64 : index
331-
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
329+
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
332330
//CHECK: [[c128:%.+]] = arith.constant 128 : index
333-
//CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
331+
//CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
334332
//CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
335333
%0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
336334
%1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
@@ -349,13 +347,11 @@ gpu.module @test_distribution {
349347
//CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
350348
//CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
351349
//CHECK: [[c32:%.+]] = arith.constant 32 : index
352-
//CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
350+
//CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
353351
//CHECK: [[c32_1:%.+]] = arith.constant 32 : index
354-
//CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
352+
//CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
355353
//CHECK: [[c0:%.+]] = arith.constant 0 : index
356354
//CHECK: [[c0_2:%.+]] = arith.constant 0 : index
357-
//CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
358-
//CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
359355
//CHECK: [[c64:%.+]] = arith.constant 64 : index
360356
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
361357
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -412,11 +408,10 @@ gpu.module @test_distribution {
412408
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
413409
//CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
414410
//CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
415-
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
411+
//CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]]
416412
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
417-
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
418413
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
419-
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
414+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
420415
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
421416
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
422417
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
@@ -430,9 +425,8 @@ gpu.module @test_distribution {
430425
//CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
431426
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
432427
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
433-
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
434428
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
435-
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
429+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
436430
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
437431
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
438432
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@ gpu.module @test_1_1_assignment {
1414
//CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
1515
//CHECK: [[C0:%.+]] = arith.constant 0 : index
1616
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
17-
//CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
18-
//CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
1917
//CHECK: [[C256:%.+]] = arith.constant 256 : index
20-
//CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
18+
//CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]]
2119
//CHECK: [[C128:%.+]] = arith.constant 128 : index
22-
//CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
20+
//CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]]
2321
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
2422
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
2523
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -37,17 +35,13 @@ gpu.module @test_1_1_assignment {
3735
//CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
3836
//CHECK: [[C0:%.+]] = arith.constant 0 : index
3937
//CHECK: [[C0_2:%.+]] = arith.constant 0 : index
40-
//CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
41-
//CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_2]] : index
4238
//CHECK: [[C256:%.+]] = arith.constant 256 : index
43-
//CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
39+
//CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]]
4440
//CHECK: [[C128:%.+]] = arith.constant 128 : index
45-
//CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
41+
//CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]]
4642
//CHECK: [[C0_3:%.+]] = arith.constant 0 : index
47-
//CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_3]]
4843
//CHECK: [[C0_4:%.+]] = arith.constant 0 : index
49-
//CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_4]]
50-
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[Y]], [[X]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
44+
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
5145
%tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
5246
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
5347
gpu.return

mlir/test/lib/Transforms/TestSingleFold.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
2626
public RewriterBase::Listener {
2727
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)
2828

29+
TestSingleFold() = default;
30+
TestSingleFold(const TestSingleFold &pass) : PassWrapper(pass) {}
31+
2932
StringRef getArgument() const final { return "test-single-fold"; }
3033
StringRef getDescription() const final {
3134
return "Test single-pass operation folding and dead constant elimination";
@@ -45,13 +48,18 @@ struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
4548
if (it != existingConstants.end())
4649
existingConstants.erase(it);
4750
}
51+
52+
Option<int> maxIterations{*this, "max-iterations",
53+
llvm::cl::desc("Max iterations in the tryToFold"),
54+
llvm::cl::init(1)};
4855
};
4956
} // namespace
5057

5158
void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
5259
// Attempt to fold the specified operation, including handling unused or
5360
// duplicated constants.
54-
(void)helper.tryToFold(op);
61+
bool inPlaceUpdate = false;
62+
(void)helper.tryToFold(op, &inPlaceUpdate, maxIterations);
5563
}
5664

5765
void TestSingleFold::runOnOperation() {

0 commit comments

Comments
 (0)