Conversation

@nhat-nguyen nhat-nguyen commented May 20, 2025

Currently, the liveness analysis always marks operands yielded from regions whose parent ops aren't classified as RegionBranchOpInterface or CallableOpInterface as non-live. Examples of such ops include linalg.generic (with linalg.yield as the terminator) or gpu ops (with gpu.yield as the terminator).

This in turn makes the remove-dead-values pass incorrectly remove the bodies of these ops, leading to invalid IR. Because these ops define their own semantics, I have conservatively marked all operands of these yield ops as live.
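As a concrete sketch of the failure (adapted from the gpu.all_reduce test added in this PR; in real IR the snippet would sit inside a gpu.func), the reduction body below was being deleted because %xor was considered non-live:

%c1 = arith.constant 1 : i32
%result = gpu.all_reduce %c1 uniform {
^bb(%lhs : i32, %rhs : i32):
  // Before this patch, %xor was marked non-live: gpu.all_reduce implements
  // neither RegionBranchOpInterface nor CallableOpInterface, so the yielded
  // operand fell through the analysis and remove-dead-values erased the body.
  %xor = arith.xori %lhs, %rhs : i32
  "gpu.yield"(%xor) : (i32) -> ()
} : (i32) -> (i32)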

@llvmbot llvmbot added the mlir label May 20, 2025

llvmbot commented May 20, 2025

@llvm/pr-subscribers-mlir

Author: Nhat Nguyen (nhat-nguyen)


Full diff: https://github.com/llvm/llvm-project/pull/140793.diff

2 Files Affected:

  • (modified) mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp (+7-3)
  • (modified) mlir/test/Transforms/remove-dead-values.mlir (+40)
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index c12149a1a0242..d61cdb143e7dd 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -51,7 +51,11 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
 /// A value is considered "live" iff it:
 ///   (1) has memory effects OR
 ///   (2) is returned by a public function OR
-///   (3) is used to compute a value of type (1) or (2).
+///   (3) is used to compute a value of type (1) or (2) OR
+///   (4) is returned by a return-like op whose parent isn't a callable
+///       nor a RegionBranchOpInterface (e.g.: linalg.yield, gpu.yield,...)
+///       These ops have their own semantics, so we conservatively mark the
+///       yielded value as live.
 /// It is also to be noted that a value could be of multiple types (1/2/3) at
 /// the same time.
 ///
@@ -73,8 +77,8 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
 LogicalResult
 LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
                                  ArrayRef<const Liveness *> results) {
-  // This marks values of type (1.a) liveness as "live".
-  if (!isMemoryEffectFree(op)) {
+  // This marks values of type (1.a) and (4) liveness as "live".
+  if (!isMemoryEffectFree(op) || op->hasTrait<OpTrait::ReturnLike>()) {
     for (auto *operand : operands)
       propagateIfChanged(operand, operand->markLive());
   }
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 21d53b0742e07..87df57df54d7a 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -468,3 +468,43 @@ func.func private @no_block_func_declaration() -> ()
 
 // CHECK: llvm.func @no_block_external_func()
 llvm.func @no_block_external_func() attributes {sym_visibility = "private"}
+
+// -----
+
+// Check that yielded values aren't incorrectly removed in gpu regions
+gpu.module @test_module_3 {
+  gpu.func @gpu_all_reduce_region() {
+    %arg0 = arith.constant 1 : i32
+    %result = gpu.all_reduce %arg0 uniform {
+    ^bb(%lhs : i32, %rhs : i32):
+      %xor = arith.xori %lhs, %rhs : i32
+      "gpu.yield"(%xor) : (i32) -> ()
+    } : (i32) -> (i32)
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: func @gpu_all_reduce_region()
+// CHECK: %[[yield:.*]] = arith.xori %{{.*}}, %{{.*}} : i32
+// CHECK: gpu.yield %[[yield]] : i32
+
+// -----
+
+// Check that yielded values aren't incorrectly removed in linalg regions
+module {
+  func.func @linalg_red_add(%arg0: tensor<?xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+    %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (0)>],
+      iterator_types = ["reduction"]
+    } ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<1xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %1 = arith.addf %in, %out : f32
+      linalg.yield %1 : f32
+    } -> tensor<1xf32>
+    return %0 : tensor<1xf32>
+  }
+}
+
+// CHECK-LABEL: func @linalg_red_add
+// CHECK: %[[yield:.*]] = arith.addf %{{.*}}, %{{.*}} : f32
+// CHECK: linalg.yield %[[yield]] : f32
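
(Not part of the diff: the new cases rely on the test file's existing RUN line, which isn't shown here. An invocation of the pass typically looks something like the line below; treat it as an assumed example rather than the file's actual contents.)

// Assumed invocation, for illustration only; the real RUN line in
// remove-dead-values.mlir may use different flags.
// RUN: mlir-opt %s -remove-dead-values -split-input-file | FileCheck %s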

@nhat-nguyen
Contributor Author

@Mogball I saw you have reviewed a few PRs in this file, so just wanted to get your opinion on this fix. Thank you! :)

Contributor

@Mogball Mogball left a comment

Seems reasonable to me

@nhat-nguyen
Contributor Author

> Seems reasonable to me

Thanks for the review! Would you mind helping me merge this too?

@Mogball Mogball merged commit 7385772 into llvm:main May 21, 2025
11 checks passed
nhat-nguyen added a commit to microsoft/triton-shared that referenced this pull request Jun 19, 2025
Previously, before the fallback conversion to the PtrDialect, the
structured-to-memref pass had to convert loop iter-args of triton pointer
type to unranked memref. This conversion ensured that all types coming out
of triton-shared were MLIR built-in types, which allowed the CPU backend to
correctly lower the IR to LLVM. In practice, however, structured ops do not
need to use the loop iter-args, since ptr-analysis generates load/store ops
that use the kernel arguments directly as their source; this means the
conversion is mostly unnecessary.

With the introduction of the fallback using the PtrDialect (triton-to-ptr),
we also convert the loop iter-args of triton pointer type to the PtrDialect's
ptr type. This conversion, combined with the conversion to unranked memref
above, means we end up with `unrealized_conversion_cast` ops that convert
back and forth between these two types when handling triton programs that mix
structured and unstructured accesses in loops.
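
A purely illustrative sketch of that cast ping-pong on a loop iter-arg (the ptr-dialect type spelling here is assumed, not taken from triton-shared):

// Hypothetical types for illustration: one conversion produces memref<*xf32>,
// the other expects a ptr-dialect value, so casts chain back and forth.
%0 = builtin.unrealized_conversion_cast %iter_arg : !ptr.ptr to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %0 : memref<*xf32> to !ptr.ptr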

To solve this issue, we:

- remove the conversion of loop iter-args of triton ptr type to unranked
  memref, since it is unnecessary as described above
- run remove-dead-values to remove unused loop iter-args; this pass
  previously could not run in the presence of ops with arbitrary regions,
  but that has now been fixed in llvm/llvm-project#140793. Running
  remove-dead-values gives two benefits:
  - removing all loop iter-args that are no longer used after ptr-analysis
  - making our codegen more efficient in general
