From 391189f765cb038fb56ca67b3b6bf0906a4686a6 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sat, 23 Nov 2024 00:02:11 +0000 Subject: [PATCH 1/4] Allow SymbolUserOpInterface operators to be used in RemoveDeadValues pass. --- mlir/lib/Transforms/RemoveDeadValues.cpp | 6 ++---- mlir/test/Transforms/remove-dead-values.mlir | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index b82280dda8ba7..0aa9dcb36681b 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -577,10 +577,8 @@ void RemoveDeadValues::runOnOperation() { WalkResult acceptableIR = module->walk([&](Operation *op) { if (op == module) return WalkResult::advance(); - if (isa(op) || - (isa(op) && !isa(op))) { - op->emitError() << "cannot optimize an IR with " - "non-call symbol user ops or branch ops\n"; + if (isa(op)) { + op->emitError() << "cannot optimize an IR with branch ops\n"; return WalkResult::interrupt(); } return WalkResult::advance(); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 47137fc6430fe..7a8d49681a4b1 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -3,9 +3,10 @@ // The IR is updated regardless of memref.global private constant // module { - memref.global "private" constant @__something_global : memref = dense<0> + memref.global "private" constant @global_buffer : memref<5xi32> = dense<[1, 2, 3, 4, 5]> : tensor<5xi32> func.func @main(%arg0: i32) -> i32 { %0 = tensor.empty() : tensor<10xbf16> + %1 = memref.get_global @global_buffer : memref<5xi32> // CHECK-NOT: tensor.empty return %arg0 : i32 } From 801ca7f11f8abcbbb66ba4cfca6e9e03f2b21481 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sat, 23 Nov 2024 00:37:05 +0000 Subject: [PATCH 2/4] Update error check. --- mlir/test/Transforms/remove-dead-values.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 7a8d49681a4b1..c215a2b8fd77c 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -3,10 +3,10 @@ // The IR is updated regardless of memref.global private constant // module { - memref.global "private" constant @global_buffer : memref<5xi32> = dense<[1, 2, 3, 4, 5]> : tensor<5xi32> + memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} func.func @main(%arg0: i32) -> i32 { %0 = tensor.empty() : tensor<10xbf16> - %1 = memref.get_global @global_buffer : memref<5xi32> + %1 = memref.get_global @__constant_4xi32 : memref<4xi32> // CHECK-NOT: tensor.empty return %arg0 : i32 } @@ -30,7 +30,7 @@ module @named_module_acceptable { // func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { %non_live = arith.constant 0 : i32 - // expected-error @+1 {{cannot optimize an IR with non-call symbol user ops or branch ops}} + // expected-error @+1 {{cannot optimize an IR with branch ops}} cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32) ^bb1(%non_live_0 : i32): cf.br ^bb3 From 4467082f3a1bb889a886335659bd4e7b85305e6a Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sat, 23 Nov 2024 00:46:52 +0000 Subject: [PATCH 3/4] Update test with FILE CHECKS. --- mlir/test/Transforms/remove-dead-values.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index c215a2b8fd77c..b469ccb0dd950 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -3,6 +3,7 @@ // The IR is updated regardless of memref.global private constant // module { + // CHECK: memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} func.func @main(%arg0: i32) -> i32 { %0 = tensor.empty() : tensor<10xbf16> From 643b6008cf8839e53c94f2d0a5dadbcafad98353 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sat, 23 Nov 2024 17:17:56 +0000 Subject: [PATCH 4/4] Update test with FILE CHECKS. --- mlir/test/Transforms/remove-dead-values.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index b469ccb0dd950..826f6159a36b6 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -7,6 +7,7 @@ module { memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} func.func @main(%arg0: i32) -> i32 { %0 = tensor.empty() : tensor<10xbf16> + // CHECK-NOT: memref.get_global %1 = memref.get_global @__constant_4xi32 : memref<4xi32> // CHECK-NOT: tensor.empty return %arg0 : i32