1414
1515#include " mlir/Conversion/SCFToGPU/SCFToGPU.h"
1616
17+ #include " mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
1718#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
1819#include " mlir/Dialect/Affine/IR/AffineOps.h"
1920#include " mlir/Dialect/Arith/IR/Arith.h"
2728#include " mlir/Interfaces/SideEffectInterfaces.h"
2829#include " mlir/Transforms/DialectConversion.h"
2930#include " mlir/Transforms/RegionUtils.h"
31+ #include " llvm/ADT/DenseSet.h"
3032#include " llvm/Support/DebugLog.h"
3133#include < optional>
3234
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
625627 bool seenSideeffects = false ;
626628 // Whether we have left a nesting scope (and hence are no longer innermost).
627629 bool leftNestingScope = false ;
630+ LocalAliasAnalysis aliasAnalysis;
631+ llvm::DenseSet<Value> writtenBuffer;
628632 while (!worklist.empty ()) {
629633 Operation *op = worklist.pop_back_val ();
630634 // Now walk over the body and clone it.
631635 // TODO: This is only correct if there either is no further scf.parallel
632- // nested or this code is side-effect free. Otherwise we might need
633- // predication. We are overly conservative for now and only allow
634- // side-effects in the innermost scope .
636+ // nested or the side effects of this code only write buffers that do not
637+ // alias any buffer accessed by the inner loop. Otherwise we might need
638+ // predication.
635639 if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
636640 // Before entering a nested scope, make sure there have been no
637- // sideeffects until now.
638- if (seenSideeffects)
639- return failure ();
641+ // side effects until now, or that the nested operations do not access
642+ // any buffer written by the outer scope.
643+ if (seenSideeffects) {
644+ WalkResult walkRes = nestedParallel.walk ([&](Operation *nestedOp) {
645+ if (isMemoryEffectFree (nestedOp))
646+ return WalkResult::advance ();
647+
648+ auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
649+ if (!memEffectInterface)
650+ return WalkResult::advance ();
651+
652+ SmallVector<MemoryEffects::EffectInstance> effects;
653+ memEffectInterface.getEffects (effects);
654+ for (const MemoryEffects::EffectInstance &effect : effects) {
655+ if (isa<MemoryEffects::Read>(effect.getEffect ()) ||
656+ isa<MemoryEffects::Write>(effect.getEffect ())) {
657+ Value baseBuffer = effect.getValue ();
658+ if (!baseBuffer)
659+ return WalkResult::interrupt ();
660+ for (Value val : writtenBuffer) {
661+ if (aliasAnalysis.alias (baseBuffer, val) !=
662+ AliasResult::NoAlias) {
663+ return WalkResult::interrupt ();
664+ }
665+ }
666+ }
667+ }
668+ return WalkResult::advance ();
669+ });
670+ if (walkRes.wasInterrupted ())
671+ return failure ();
672+ }
640673 // A nested scf.parallel needs insertion of code to compute indices.
641674 // Insert that now. This will also update the worklist with the loops
642675 // body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
650683 rewriter.setInsertionPointAfter (parent);
651684 leftNestingScope = true ;
652685 seenSideeffects = false ;
686+ writtenBuffer.clear ();
653687 } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
654688 // Convert scf.reduction op
655689 auto parentLoop = op->getParentOfType <ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
682716 Operation *clone = rewriter.clone (*op, cloningMap);
683717 cloningMap.map (op->getResults (), clone->getResults ());
684718 // Check for side effects.
719+ if (!isMemoryEffectFree (clone)) {
720+ // Record the buffer accessed by the operations with write effects.
721+ if (auto memEffectInterface =
722+ dyn_cast<MemoryEffectOpInterface>(clone)) {
723+ SmallVector<MemoryEffects::EffectInstance> effects;
724+ memEffectInterface.getEffects (effects);
725+ for (const MemoryEffects::EffectInstance &effect : effects) {
726+ if (isa<MemoryEffects::Write>(effect.getEffect ())) {
727+ Value writtenBase = effect.getValue ();
728+ // Conservatively return failure if we cannot find the written
729+ // address.
730+ if (!writtenBase)
731+ return failure ();
732+ writtenBuffer.insert (writtenBase);
733+ }
734+ }
735+ }
736+ }
685737 // TODO: Handle region side effects properly.
686738 seenSideeffects |=
687739 !isMemoryEffectFree (clone) || clone->getNumRegions () != 0 ;
0 commit comments