Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 87 additions & 6 deletions mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -71,7 +72,54 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
}

template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
static void replaceAllUsesInUnstructuredComputeRegionWith(
Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {

Operation *exitOp = op.getOperation();
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
// For declare enter/exit pairs, verify there is exactly one exit op using
// the token
if (!op.getToken().hasOneUse())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too restrictive, since from a language perspective multiple exits are allowed, and a dialect may choose to permit that. I think the algorithm here should collect all of the declare_exit ops and ensure that each replaced use is post-dominated by one of them.

op.emitError("declare enter token must have exactly one use");
Operation *user = *op.getToken().getUsers().begin();
auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
if (!declareExit)
op.emitError("declare enter token must be used by declare exit op");
exitOp = declareExit;
} else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am excited about the potential to support acc.enter_data. As I looked at your algorithm here, though, I realized that maybe it should be left for another change. This is because enter_data and exit_data don't come in pairs, so a bit more analysis is needed to get this right. More specifically, finding an exit_data does not constitute the end of the lifetime for a particular variable.

// For enter/exit data pairs, find the corresponding exit_data op
Operation *nextOp = op.getOperation()->getNextNode();
while (nextOp && !isa<acc::ExitDataOp>(nextOp))
nextOp = nextOp->getNextNode();
if (!nextOp)
op.emitError("enter data must have a corresponding exit data op");
exitOp = nextOp;
}

for (auto p : values) {
Value hostVal = std::get<0>(p);
Value deviceVal = std::get<1>(p);
for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
Operation *owner = use.getOwner();
// Only rewrite the use when both of the following hold:
//  1. The acc entry operation dominates the use.
//  2. The matching acc exit operation (exitOp) post-dominates the use.
// Any other use is left pointing at the host value.
if (!domInfo.dominates(op.getOperation(), owner) ||
!postDomInfo.postDominates(exitOp, owner))
continue;
if (insideAccComputeRegion(owner))
use.set(deviceVal);
}
}
}

template <typename Op>
static void
collectAndReplaceInRegion(Op &op, bool hostToDevice,
DominanceInfo *domInfo = nullptr,
PostDominanceInfo *postDomInfo = nullptr) {
llvm::SmallVector<std::pair<Value, Value>> values;

if constexpr (std::is_same_v<Op, acc::LoopOp>) {
Expand All @@ -82,16 +130,27 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp> &&
!std::is_same_v<Op, acc::HostDataOp>) {
!std::is_same_v<Op, acc::HostDataOp> &&
!std::is_same_v<Op, acc::DeclareEnterOp> &&
!std::is_same_v<Op, acc::EnterDataOp>) {
collectVars(op.getReductionOperands(), values, hostToDevice);
collectVars(op.getPrivateOperands(), values, hostToDevice);
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
}
}

for (auto p : values)
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
op.getRegion());
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
std::is_same_v<Op, acc::EnterDataOp>) {
assert(domInfo && postDomInfo &&
"Dominance info required for DeclareEnterOp");
replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
*postDomInfo);
} else {
for (auto p : values) {
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
op.getRegion());
}
}
}

class LegalizeDataValuesInRegion
Expand All @@ -105,10 +164,16 @@ class LegalizeDataValuesInRegion
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();

// Dominance info is computed lazily, on the first unstructured data op seen.
DominanceInfo domInfo;
PostDominanceInfo postDomInfo;
bool computedDomInfo = false;

funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
applyToAccDataConstruct))
applyToAccDataConstruct) &&
!isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
return;

if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
Expand All @@ -125,6 +190,22 @@ class LegalizeDataValuesInRegion
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
if (!computedDomInfo) {
domInfo = DominanceInfo(funcOp);
postDomInfo = PostDominanceInfo(funcOp);
computedDomInfo = true;
}
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
&postDomInfo);
} else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
if (!computedDomInfo) {
domInfo = DominanceInfo(funcOp);
postDomInfo = PostDominanceInfo(funcOp);
computedDomInfo = true;
}
collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
&postDomInfo);
} else {
llvm_unreachable("unsupported acc region op");
}
Expand Down
Loading