 
 #include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
 
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/DebugLog.h"
 #include <optional>
 
@@ -625,6 +627,8 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
   bool seenSideeffects = false;
   // Whether we have left a nesting scope (and hence are no longer innermost).
   bool leftNestingScope = false;
+  LocalAliasAnalysis aliasAnalysis;
+  llvm::DenseSet<Value> writtenBuffer;
   while (!worklist.empty()) {
     Operation *op = worklist.pop_back_val();
     // Now walk over the body and clone it.
@@ -635,8 +639,39 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
     if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
       // Before entering a nested scope, make sure there have been no
       // sideeffects until now.
-      if (seenSideeffects)
-        return failure();
+      if (seenSideeffects) {
+        // Go through all operations in the nested parallel and check if any
+        // of the side-effecting operations access buffers that have been
+        // written to in the outer scope.
+        bool accessesWrittenBuffer = false;
+        nestedParallel.walk([&](Operation *nestedOp) {
+          if (accessesWrittenBuffer)
+            return;
+          if (isMemoryEffectFree(nestedOp))
+            return;
+
+          if (auto memEffectInterface =
+                  dyn_cast<MemoryEffectOpInterface>(nestedOp)) {
+            SmallVector<MemoryEffects::EffectInstance> effects;
+            memEffectInterface.getEffects(effects);
+            for (const auto &effect : effects) {
+              if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+                  isa<MemoryEffects::Write>(effect.getEffect())) {
+                Value baseBuffer = effect.getValue();
+                for (auto val : writtenBuffer) {
+                  if (aliasAnalysis.alias(baseBuffer, val) !=
+                      AliasResult::NoAlias) {
+                    accessesWrittenBuffer = true;
+                    return;
+                  }
+                }
+              }
+            }
+          }
+        });
+        if (accessesWrittenBuffer)
+          return failure();
+      }
       // A nested scf.parallel needs insertion of code to compute indices.
       // Insert that now. This will also update the worklist with the loops
       // body.
@@ -650,6 +685,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       rewriter.setInsertionPointAfter(parent);
       leftNestingScope = true;
       seenSideeffects = false;
+      writtenBuffer.clear();
     } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
       // Convert scf.reduction op
       auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +718,18 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       Operation *clone = rewriter.clone(*op, cloningMap);
       cloningMap.map(op->getResults(), clone->getResults());
       // Check for side effects.
+      if (!isMemoryEffectFree(clone)) {
+        // Record the buffers accessed by operations with write effects.
+        if (auto memEffectInterface =
+                dyn_cast<MemoryEffectOpInterface>(clone)) {
+          SmallVector<MemoryEffects::EffectInstance> effects;
+          memEffectInterface.getEffects(effects);
+          for (const auto &effect : effects) {
+            if (isa<MemoryEffects::Write>(effect.getEffect()))
+              writtenBuffer.insert(effect.getValue());
+          }
+        }
+      }
       // TODO: Handle region side effects properly.
       seenSideeffects |=
           !isMemoryEffectFree(clone) || clone->getNumRegions() != 0;
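
For context, below is a hedged sketch of the kind of loop nest this change is meant to admit. It is not taken from the patch or its tests: the function name, buffer names, and shapes are illustrative assumptions, and the gpu mapping attributes that ParallelToGpuLaunchLowering requires on both loops are omitted for brevity. The outer scf.parallel stores into %out before reaching the nested loop, so seenSideeffects is set and the old code would have returned failure() unconditionally. With this change, the nested loop's accesses are compared against the written buffer %out; since %in, %out, and %tmp are distinct local allocations, LocalAliasAnalysis reports NoAlias for every pair and the lowering can proceed. Any nested access that may alias a previously written buffer still aborts the conversion as before.

func.func @nested_parallel_sketch(%n: index, %m: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %in = memref.alloc(%n) : memref<?xf32>
  %out = memref.alloc(%n) : memref<?xf32>
  %tmp = memref.alloc(%m) : memref<?xf32>
  scf.parallel (%i) = (%c0) to (%n) step (%c1) {
    %v = memref.load %in[%i] : memref<?xf32>
    // Side effect before the nested loop: %out is recorded as a written buffer.
    memref.store %v, %out[%i] : memref<?xf32>
    scf.parallel (%j) = (%c0) to (%m) step (%c1) {
      // The nested loop only touches %in and %tmp, which do not alias %out.
      %w = memref.load %in[%j] : memref<?xf32>
      memref.store %w, %tmp[%j] : memref<?xf32>
      scf.reduce
    }
    scf.reduce
  }
  return
}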