Skip to content

Commit 82ebb7b

Browse files
committed
Revert "Revert "[MLIR] Improve in-place folding to iterate until fixed-point (llvm#160615)""
This reverts commit 2bd7a5f.
1 parent 90a6758 commit 82ebb7b

File tree

9 files changed

+77
-42
lines changed

9 files changed

+77
-42
lines changed

mlir/include/mlir/Transforms/FoldUtils.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ class OperationFolder {
4040
/// deduplicated constants. If successful, replaces `op`'s uses with
4141
/// folded results, and returns success. If the op was completely folded it is
4242
/// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
43-
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr);
43+
/// On success, if the op was folded in place, the folder is re-invoked until
44+
/// `maxIterations` is reached (default INT_MAX).
45+
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr,
46+
int maxIterations = INT_MAX);
4447

4548
/// Tries to fold a pre-existing constant operation. `constValue` represents
4649
/// the value of the constant, and can be optionally passed if the value is
@@ -82,7 +85,10 @@ class OperationFolder {
8285

8386
/// Tries to perform folding on the given `op`. If successful, populates
8487
/// `results` with the results of the folding.
85-
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results);
88+
/// On success, if the op was folded in place, the folder is re-invoked until
89+
/// `maxIterations` is reached (default INT_MAX).
90+
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results,
91+
int maxIterations = INT_MAX);
8692

8793
/// Try to process a set of fold results. Populates `results` on success,
8894
/// otherwise leaves it unchanged.

mlir/lib/IR/Builders.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/IRMapping.h"
1515
#include "mlir/IR/Matchers.h"
1616
#include "llvm/ADT/SmallVectorExtras.h"
17+
#include "llvm/Support/DebugLog.h"
1718

1819
using namespace mlir;
1920

@@ -486,9 +487,18 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
486487

487488
// Try to fold the operation.
488489
SmallVector<OpFoldResult, 4> foldResults;
490+
LDBG() << "Trying to fold: "
491+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
489492
if (failed(op->fold(foldResults)))
490493
return cleanupFailure();
491494

495+
int count = 0;
496+
do {
497+
LDBG() << "Folded in place #" << count
498+
<< " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
499+
count++;
500+
} while (foldResults.empty() && succeeded(op->fold(foldResults)));
501+
492502
// An in-place fold does not require generation of any constants.
493503
if (foldResults.empty())
494504
return success();

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/Builders.h"
1717
#include "mlir/IR/Matchers.h"
1818
#include "mlir/IR/Operation.h"
19+
#include "llvm/Support/DebugLog.h"
1920

2021
using namespace mlir;
2122

@@ -67,7 +68,8 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
6768
// OperationFolder
6869
//===----------------------------------------------------------------------===//
6970

70-
LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
71+
LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate,
72+
int maxIterations) {
7173
if (inPlaceUpdate)
7274
*inPlaceUpdate = false;
7375

@@ -86,7 +88,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
8688

8789
// Try to fold the operation.
8890
SmallVector<Value, 8> results;
89-
if (failed(tryToFold(op, results)))
91+
if (failed(tryToFold(op, results, maxIterations)))
9092
return failure();
9193

9294
// Check to see if the operation was just updated in place.
@@ -224,10 +226,19 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
224226
/// Tries to perform folding on the given `op`. If successful, populates
225227
/// `results` with the results of the folding.
226228
LogicalResult OperationFolder::tryToFold(Operation *op,
227-
SmallVectorImpl<Value> &results) {
229+
SmallVectorImpl<Value> &results,
230+
int maxIterations) {
228231
SmallVector<OpFoldResult, 8> foldResults;
229-
if (failed(op->fold(foldResults)) ||
230-
failed(processFoldResults(op, results, foldResults)))
232+
if (failed(op->fold(foldResults)))
233+
return failure();
234+
int count = 1;
235+
do {
236+
LDBG() << "Folded in place #" << count
237+
<< " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
238+
} while (count++ < maxIterations && foldResults.empty() &&
239+
succeeded(op->fold(foldResults)));
240+
241+
if (failed(processFoldResults(op, results, foldResults)))
231242
return failure();
232243
return success();
233244
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Test with the default (one application of the folder) and then with 2 iterations.
2+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-single-fold))" | FileCheck %s --check-prefixes=CHECK,CHECK-ONE
3+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-single-fold{max-iterations=2}))" | FileCheck %s --check-prefixes=CHECK,CHECK-TWO
4+
5+
6+
// Folding this entirely requires moving the constant to the right
7+
// before invoking the op-specific folder.
8+
// With one iteration, we just push the constant to the right.
9+
// With a second iteration, we actually fold the "add" (x+0->x)
10+
// CHECK: func @recurse_fold_traits(%[[ARG0:.*]]: i32)
11+
func.func @recurse_fold_traits(%arg0 : i32) -> i32 {
12+
%cst0 = arith.constant 0 : i32
13+
// CHECK-ONE: %[[ADD:.*]] = arith.addi %[[ARG0]],
14+
%res = arith.addi %cst0, %arg0 : i32
15+
// CHECK-ONE: return %[[ADD]] : i32
16+
// CHECK-TWO: return %[[ARG0]] : i32
17+
return %res : i32
18+
}

mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ gpu.module @test {
77
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
88
//CHECK: [[c32:%.+]] = arith.constant 32 : index
99
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
10-
//CHECK: [[c0:%.+]] = arith.constant 0 : index
11-
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
1210
//CHECK: [[c128:%.+]] = arith.constant 128 : index
13-
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
11+
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
1412
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
1513
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
1614
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
@@ -23,10 +21,8 @@ gpu.module @test {
2321
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
2422
//CHECK: [[c32:%.+]] = arith.constant 32 : index
2523
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
26-
//CHECK: [[c0:%.+]] = arith.constant 0 : index
27-
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
2824
//CHECK: [[c128:%.+]] = arith.constant 128 : index
29-
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
25+
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
3026
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
3127
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
3228
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,10 @@ gpu.module @test_round_robin_assignment {
2727
//CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
2828
//CHECK: [[C0:%.+]] = arith.constant 0 : index
2929
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
30-
//CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
31-
//CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
3230
//CHECK: [[C128:%.+]] = arith.constant 128 : index
33-
//CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
31+
//CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
3432
//CHECK: [[C64_2:%.+]] = arith.constant 64 : index
35-
//CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
33+
//CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
3634
//CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
3735
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
3836
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,10 @@ gpu.module @test_distribution {
330330
//CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
331331
//CHECK: [[c0:%.+]] = arith.constant 0 : index
332332
//CHECK: [[c0_1:%.+]] = arith.constant 0 : index
333-
//CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
334-
//CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
335333
//CHECK: [[c64:%.+]] = arith.constant 64 : index
336-
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
334+
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
337335
//CHECK: [[c128:%.+]] = arith.constant 128 : index
338-
//CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
336+
//CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
339337
//CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
340338
%0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
341339
%1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
@@ -354,13 +352,11 @@ gpu.module @test_distribution {
354352
//CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
355353
//CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
356354
//CHECK: [[c32:%.+]] = arith.constant 32 : index
357-
//CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
355+
//CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
358356
//CHECK: [[c32_1:%.+]] = arith.constant 32 : index
359-
//CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
357+
//CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
360358
//CHECK: [[c0:%.+]] = arith.constant 0 : index
361359
//CHECK: [[c0_2:%.+]] = arith.constant 0 : index
362-
//CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
363-
//CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
364360
//CHECK: [[c64:%.+]] = arith.constant 64 : index
365361
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
366362
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -417,11 +413,10 @@ gpu.module @test_distribution {
417413
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
418414
//CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
419415
//CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
420-
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
416+
//CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]]
421417
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
422-
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
423418
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
424-
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
419+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
425420
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
426421
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
427422
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
@@ -435,9 +430,8 @@ gpu.module @test_distribution {
435430
//CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
436431
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
437432
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
438-
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
439433
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
440-
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
434+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
441435
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
442436
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
443437
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@ gpu.module @test_1_1_assignment {
1414
//CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
1515
//CHECK: [[C0:%.+]] = arith.constant 0 : index
1616
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
17-
//CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
18-
//CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
1917
//CHECK: [[C256:%.+]] = arith.constant 256 : index
20-
//CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
18+
//CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]]
2119
//CHECK: [[C128:%.+]] = arith.constant 128 : index
22-
//CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
20+
//CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]]
2321
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
2422
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
2523
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -37,17 +35,13 @@ gpu.module @test_1_1_assignment {
3735
//CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
3836
//CHECK: [[C0:%.+]] = arith.constant 0 : index
3937
//CHECK: [[C0_2:%.+]] = arith.constant 0 : index
40-
//CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
41-
//CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_2]] : index
4238
//CHECK: [[C256:%.+]] = arith.constant 256 : index
43-
//CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
39+
//CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]]
4440
//CHECK: [[C128:%.+]] = arith.constant 128 : index
45-
//CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
41+
//CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]]
4642
//CHECK: [[C0_3:%.+]] = arith.constant 0 : index
47-
//CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_3]]
4843
//CHECK: [[C0_4:%.+]] = arith.constant 0 : index
49-
//CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_4]]
50-
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[Y]], [[X]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
44+
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
5145
%tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
5246
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
5347
gpu.return

mlir/test/lib/Transforms/TestSingleFold.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
2626
public RewriterBase::Listener {
2727
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)
2828

29+
TestSingleFold() = default;
30+
TestSingleFold(const TestSingleFold &pass) : PassWrapper(pass) {}
31+
2932
StringRef getArgument() const final { return "test-single-fold"; }
3033
StringRef getDescription() const final {
3134
return "Test single-pass operation folding and dead constant elimination";
@@ -45,13 +48,18 @@ struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
4548
if (it != existingConstants.end())
4649
existingConstants.erase(it);
4750
}
51+
52+
Option<int> maxIterations{*this, "max-iterations",
53+
llvm::cl::desc("Max iterations in the tryToFold"),
54+
llvm::cl::init(1)};
4855
};
4956
} // namespace
5057

5158
void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
5259
// Attempt to fold the specified operation, including handling unused or
5360
// duplicated constants.
54-
(void)helper.tryToFold(op);
61+
bool inPlaceUpdate = false;
62+
(void)helper.tryToFold(op, &inPlaceUpdate, maxIterations);
5563
}
5664

5765
void TestSingleFold::runOnOperation() {

0 commit comments

Comments
 (0)