diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 481b14cdb4622..b0fb5b0785142 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1078,8 +1078,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [ def AtomicYieldOp : MemRef_Op<"atomic_yield", [ HasParent<"GenericAtomicRMWOp">, - Pure, - Terminator + Pure, Terminator, ReturnLike ]> { let summary = "yield operation for GenericAtomicRMWOp"; let description = [{ diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 8c2a1cf7546f3..e55a9160f5b34 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -510,3 +510,18 @@ module { // CHECK: %[[yield:.*]] = arith.addf %{{.*}}, %{{.*}} : f32 // CHECK: linalg.yield %[[yield]] : f32 // CHECK-NOT: arith.subf + +// ----- + +// CHECK-LABEL: func.func @test_atomic_yield +func.func @test_atomic_yield(%I: memref<10xf32>, %idx : index) { + // CHECK: memref.generic_atomic_rmw + %x = memref.generic_atomic_rmw %I[%idx] : memref<10xf32> { + ^bb0(%current_value : f32): + // CHECK: arith.constant + %c1 = arith.constant 1.0 : f32 + // CHECK: memref.atomic_yield + memref.atomic_yield %c1 : f32 + } + func.return +}