Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp, \
mlir::acc::DeclareEnterOp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be added to this list because although it can result from a structured declare construct, the operation itself allows unstructured flow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I realized that was a mistake; I changed it in the next commit.

#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
Expand Down
56 changes: 51 additions & 5 deletions mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
}
}

/// Rewrites uses of host values that are bracketed by an acc.declare_enter /
/// acc.declare_exit pair.
///
/// \param op          The declare-enter op whose token identifies the pair.
/// \param values      (original, replacement) value pairs collected from the
///                    op's data clause operands; each use of the first value
///                    that qualifies is redirected to the second.
/// \param domInfo     Dominance info used to check that `op` dominates a use.
/// \param postDomInfo Post-dominance info used to check that the matching
///                    declare-exit op post-dominates a use.
///
/// A use is rewritten only when its owner is dominated by `op`, is
/// post-dominated by the declare-exit op, and `insideAccComputeRegion(owner)`
/// holds. If the token does not have exactly one use, or that use is not an
/// acc::DeclareExitOp, an error is emitted on `op` and nothing is rewritten.
static void processDeclareEnterExit(
    acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
    DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
  // The token must be consumed by exactly one op. Bail out after diagnosing:
  // with zero uses the users range is empty and must not be dereferenced.
  if (!op.getToken().hasOneUse()) {
    op.emitError("declare enter token must have exactly one use");
    return;
  }

  // The single consumer must be the matching declare-exit. Bail out after
  // diagnosing so the null op handle is never dereferenced below.
  Operation *user = *op.getToken().getUsers().begin();
  auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
  if (!declareExit) {
    op.emitError("declare enter token must be used by declare exit op");
    return;
  }

  for (const auto &[hostVal, deviceVal] : values) {
    // Early-inc iteration: `use.set` unlinks the use from hostVal's use list
    // while we are walking it.
    for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
      Operation *owner = use.getOwner();
      // Only uses that lie between the enter and the exit are candidates.
      if (!domInfo.dominates(op.getOperation(), owner) ||
          !postDomInfo.postDominates(declareExit.getOperation(), owner))
        continue;
      if (insideAccComputeRegion(owner))
        use.set(deviceVal);
    }
  }
}

template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
static void
collectAndReplaceInRegion(Op &op, bool hostToDevice,
DominanceInfo *domInfo = nullptr,
PostDominanceInfo *postDomInfo = nullptr) {
llvm::SmallVector<std::pair<Value, Value>> values;

if constexpr (std::is_same_v<Op, acc::LoopOp>) {
Expand All @@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp> &&
!std::is_same_v<Op, acc::HostDataOp>) {
!std::is_same_v<Op, acc::HostDataOp> &&
!std::is_same_v<Op, acc::DeclareEnterOp>) {
collectVars(op.getReductionOperands(), values, hostToDevice);
collectVars(op.getPrivateOperands(), values, hostToDevice);
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
}
}

for (auto p : values)
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
op.getRegion());
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
assert(domInfo && postDomInfo &&
"Dominance info required for DeclareEnterOp");
processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
} else {
for (auto p : values) {
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
op.getRegion());
}
}
}

class LegalizeDataValuesInRegion
Expand All @@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();

// Get dominance info for the function
DominanceInfo domInfo(funcOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's compute this lazily: only compute it the first time you encounter acc.declare_enter.

PostDominanceInfo postDomInfo(funcOp);

funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
Expand All @@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
&postDomInfo);
} else {
llvm_unreachable("unsupported acc region op");
}
Expand Down
Loading