diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index 1adc381092bf3..0ffd8131b8934 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -132,6 +132,29 @@ collectEffects(Operation *op, return false; } +/// Get all effects before the given operation caused by other operations in the +/// same block. That is, this will not consider operations beyond the block. +static bool +getEffectsBeforeInBlock(Operation *op, + SmallVectorImpl &effects, + bool stopAtBarrier) { + if (op == &op->getBlock()->front()) + return true; + + for (Operation *it = op->getPrevNode(); it != nullptr; + it = it->getPrevNode()) { + if (isa(it)) { + if (stopAtBarrier) + return true; + continue; + } + + if (!collectEffects(it, effects)) + return false; + } + return true; +} + /// Collects memory effects from operations that may be executed before `op` in /// a trivial structured control flow, e.g., without branches. Stops at the /// parallel region boundary or at the barrier operation if `stopAtBarrier` is @@ -153,19 +176,7 @@ getEffectsBefore(Operation *op, } // Collect all effects before the op. - if (op != &op->getBlock()->front()) { - for (Operation *it = op->getPrevNode(); it != nullptr; - it = it->getPrevNode()) { - if (isa(it)) { - if (stopAtBarrier) - return true; - else - continue; - } - if (!collectEffects(it, effects)) - return false; - } - } + getEffectsBeforeInBlock(op, effects, stopAtBarrier); // Stop if reached the parallel region boundary. if (isParallelRegionBoundary(op->getParentOp())) @@ -191,8 +202,8 @@ getEffectsBefore(Operation *op, // appropriately. if (isSequentialLoopLike(op->getParentOp())) { // Assuming loop terminators have no side effects. - return getEffectsBefore(op->getBlock()->getTerminator(), effects, - /*stopAtBarrier=*/true); + return getEffectsBeforeInBlock(op->getBlock()->getTerminator(), effects, + /*stopAtBarrier=*/true); } // If the parent operation is not guaranteed to execute its (single-block) @@ -212,6 +223,28 @@ getEffectsBefore(Operation *op, return !conservative; } +/// Get all effects after the given operation caused by other operations in the +/// same block. That is, this will not consider operations beyond the block. +static bool +getEffectsAfterInBlock(Operation *op, + SmallVectorImpl &effects, + bool stopAtBarrier) { + if (op == &op->getBlock()->back()) + return true; + + for (Operation *it = op->getNextNode(); it != nullptr; + it = it->getNextNode()) { + if (isa(it)) { + if (stopAtBarrier) + return true; + continue; + } + if (!collectEffects(it, effects)) + return false; + } + return true; +} + /// Collects memory effects from operations that may be executed after `op` in /// a trivial structured control flow, e.g., without branches. Stops at the /// parallel region boundary or at the barrier operation if `stopAtBarrier` is @@ -233,17 +266,7 @@ getEffectsAfter(Operation *op, } // Collect all effects after the op. - if (op != &op->getBlock()->back()) - for (Operation *it = op->getNextNode(); it != nullptr; - it = it->getNextNode()) { - if (isa(it)) { - if (stopAtBarrier) - return true; - continue; - } - if (!collectEffects(it, effects)) - return false; - } + getEffectsAfterInBlock(op, effects, stopAtBarrier); // Stop if reached the parallel region boundary. if (isParallelRegionBoundary(op->getParentOp())) @@ -272,8 +295,8 @@ getEffectsAfter(Operation *op, return true; bool exact = collectEffects(&op->getBlock()->front(), effects); - return getEffectsAfter(&op->getBlock()->front(), effects, - /*stopAtBarrier=*/true) && + return getEffectsAfterInBlock(&op->getBlock()->front(), effects, + /*stopAtBarrier=*/true) && exact; } diff --git a/mlir/test/Dialect/GPU/barrier-elimination.mlir b/mlir/test/Dialect/GPU/barrier-elimination.mlir index 1f5b84937deb0..7f6619adcd78f 100644 --- a/mlir/test/Dialect/GPU/barrier-elimination.mlir +++ b/mlir/test/Dialect/GPU/barrier-elimination.mlir @@ -182,3 +182,20 @@ attributes {__parallel_region_boundary_for_test} { %4 = memref.load %C[] : memref return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32 } + +// CHECK-LABEL: @nested_loop_barrier_only +func.func @nested_loop_barrier_only() attributes {__parallel_region_boundary_for_test} { + %c0 = arith.constant 0 : index + %c42 = arith.constant 42 : index + %c1 = arith.constant 1 : index + // Note: the barrier can be removed and as consequence the loops get folded + // by the greedy rewriter. + // CHECK-NOT: scf.for + // CHECK-NOT: gpu.barrier + scf.for %j = %c0 to %c42 step %c1 { + scf.for %i = %c0 to %c42 step %c1 { + gpu.barrier + } + } + return +}