Skip to content

Commit 88027fd

Browse files
committed
address review comments
1 parent cce30f7 commit 88027fd

File tree

1 file changed

+30
-24
lines changed

1 file changed

+30
-24
lines changed

mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -641,33 +641,33 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
641641
// side effects until now or the nested operations do not access the
642642
// buffer written by the outer scope.
643643
if (seenSideeffects) {
644-
bool accessesWrittenBuffer = false;
645-
nestedParallel.walk([&](Operation *nestedOp) {
646-
if (accessesWrittenBuffer)
647-
return;
644+
WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
648645
if (isMemoryEffectFree(nestedOp))
649-
return;
650-
651-
if (auto memEffectInterface =
652-
dyn_cast<MemoryEffectOpInterface>(nestedOp)) {
653-
SmallVector<MemoryEffects::EffectInstance> effects;
654-
memEffectInterface.getEffects(effects);
655-
for (const auto &effect : effects) {
656-
if (isa<MemoryEffects::Read>(effect.getEffect()) ||
657-
isa<MemoryEffects::Write>(effect.getEffect())) {
658-
Value baseBuffer = effect.getValue();
659-
for (auto val : writtenBuffer) {
660-
if (aliasAnalysis.alias(baseBuffer, val) !=
661-
AliasResult::NoAlias) {
662-
accessesWrittenBuffer = true;
663-
return;
664-
}
646+
return WalkResult::advance();
647+
648+
auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
649+
if (!memEffectInterface)
650+
return WalkResult::advance();
651+
652+
SmallVector<MemoryEffects::EffectInstance> effects;
653+
memEffectInterface.getEffects(effects);
654+
for (const MemoryEffects::EffectInstance &effect : effects) {
655+
if (isa<MemoryEffects::Read>(effect.getEffect()) ||
656+
isa<MemoryEffects::Write>(effect.getEffect())) {
657+
Value baseBuffer = effect.getValue();
658+
if (!baseBuffer)
659+
return WalkResult::interrupt();
660+
for (Value val : writtenBuffer) {
661+
if (aliasAnalysis.alias(baseBuffer, val) !=
662+
AliasResult::NoAlias) {
663+
return WalkResult::interrupt();
665664
}
666665
}
667666
}
668667
}
668+
return WalkResult::advance();
669669
});
670-
if (accessesWrittenBuffer)
670+
if (walkRes.wasInterrupted())
671671
return failure();
672672
}
673673
// A nested scf.parallel needs insertion of code to compute indices.
@@ -722,9 +722,15 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
722722
dyn_cast<MemoryEffectOpInterface>(clone)) {
723723
SmallVector<MemoryEffects::EffectInstance> effects;
724724
memEffectInterface.getEffects(effects);
725-
for (const auto &effect : effects) {
726-
if (isa<MemoryEffects::Write>(effect.getEffect()))
727-
writtenBuffer.insert(effect.getValue());
725+
for (const MemoryEffects::EffectInstance &effect : effects) {
726+
if (isa<MemoryEffects::Write>(effect.getEffect())) {
727+
Value writtenBase = effect.getValue();
728+
// Conservatively return failure if we cannot find the written
729+
// address.
730+
if (!writtenBase)
731+
return failure();
732+
writtenBuffer.insert(writtenBase);
733+
}
728734
}
729735
}
730736
}

0 commit comments

Comments
 (0)