@@ -641,33 +641,33 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
641641 // sideeffects until now or the nested operations do not access the
642642 // buffer written by outer scope.
643643 if (seenSideeffects) {
644- bool accessesWrittenBuffer = false ;
645- nestedParallel.walk ([&](Operation *nestedOp) {
646- if (accessesWrittenBuffer)
647- return ;
644+ WalkResult walkRes = nestedParallel.walk ([&](Operation *nestedOp) {
648645 if (isMemoryEffectFree (nestedOp))
649- return ;
650-
651- if (auto memEffectInterface =
652- dyn_cast<MemoryEffectOpInterface>(nestedOp)) {
653- SmallVector<MemoryEffects::EffectInstance> effects;
654- memEffectInterface.getEffects (effects);
655- for (const auto &effect : effects) {
656- if (isa<MemoryEffects::Read>(effect.getEffect ()) ||
657- isa<MemoryEffects::Write>(effect.getEffect ())) {
658- Value baseBuffer = effect.getValue ();
659- for (auto val : writtenBuffer) {
660- if (aliasAnalysis.alias (baseBuffer, val) !=
661- AliasResult::NoAlias) {
662- accessesWrittenBuffer = true ;
663- return ;
664- }
646+ return WalkResult::advance ();
647+
648+ auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
649+ if (!memEffectInterface)
650+ return WalkResult::advance ();
651+
652+ SmallVector<MemoryEffects::EffectInstance> effects;
653+ memEffectInterface.getEffects (effects);
654+ for (const MemoryEffects::EffectInstance &effect : effects) {
655+ if (isa<MemoryEffects::Read>(effect.getEffect ()) ||
656+ isa<MemoryEffects::Write>(effect.getEffect ())) {
657+ Value baseBuffer = effect.getValue ();
658+ if (!baseBuffer)
659+ return WalkResult::interrupt ();
660+ for (Value val : writtenBuffer) {
661+ if (aliasAnalysis.alias (baseBuffer, val) !=
662+ AliasResult::NoAlias) {
663+ return WalkResult::interrupt ();
665664 }
666665 }
667666 }
668667 }
668+ return WalkResult::advance ();
669669 });
670- if (accessesWrittenBuffer )
670+ if (walkRes. wasInterrupted () )
671671 return failure ();
672672 }
673673 // A nested scf.parallel needs insertion of code to compute indices.
@@ -722,9 +722,15 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
722722 dyn_cast<MemoryEffectOpInterface>(clone)) {
723723 SmallVector<MemoryEffects::EffectInstance> effects;
724724 memEffectInterface.getEffects (effects);
725- for (const auto &effect : effects) {
726- if (isa<MemoryEffects::Write>(effect.getEffect ()))
727- writtenBuffer.insert (effect.getValue ());
725+ for (const MemoryEffects::EffectInstance &effect : effects) {
726+ if (isa<MemoryEffects::Write>(effect.getEffect ())) {
727+ Value writtenBase = effect.getValue ();
728+ // Conservatively return failure if we cannot find the written
729+ // address.
730+ if (!writtenBase)
731+ return failure ();
732+ writtenBuffer.insert (writtenBase);
733+ }
728734 }
729735 }
730736 }
0 commit comments