-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][acc] Add LegalizeDataValues support for DeclareEnterOp #138008
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
4c458ea
03b89b1
cd70987
014a8b8
823d31d
4814f8f
cb0a1b0
a132a37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
|
|
||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/OpenACC/OpenACC.h" | ||
| #include "mlir/IR/Dominance.h" | ||
| #include "mlir/Pass/Pass.h" | ||
| #include "mlir/Transforms/RegionUtils.h" | ||
| #include "llvm/Support/ErrorHandling.h" | ||
|
|
@@ -71,7 +72,54 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, | |
| } | ||
|
|
||
| template <typename Op> | ||
| static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { | ||
| /// Rewrites uses of the first value of each pair in `values` to the second | ||
| /// value, for an unstructured data construct (an "enter" op matched with a | ||
| /// separate "exit" op rather than an op that owns a region). | ||
| /// | ||
| /// A use is rewritten only when all of the following hold: | ||
| ///   * `op` dominates the operation owning the use, | ||
| ///   * the matching exit operation post-dominates that operation, and | ||
| ///   * `insideAccComputeRegion` returns true for that operation. | ||
| /// | ||
| /// The exit operation is located as follows: | ||
| ///   * acc::DeclareEnterOp: the single user of the op's token, which must be | ||
| ///     an acc::DeclareExitOp. | ||
| ///   * acc::EnterDataOp: the first acc::ExitDataOp reached by following | ||
| ///     getNextNode() from `op`. | ||
| ///   * any other Op: `op` itself is used as the exit operation. | ||
| /// | ||
| /// If no valid exit operation can be found, an error is emitted on `op` and | ||
| /// no use is rewritten. | ||
| /// | ||
| /// \param op          The acc entry operation. | ||
| /// \param values      (original, replacement) value pairs to apply. | ||
| /// \param domInfo     Dominance info used to test `op` against each use. | ||
| /// \param postDomInfo Post-dominance info used to test the exit operation | ||
| ///                    against each use. | ||
| static void replaceAllUsesInUnstructuredComputeRegionWith( | ||
|     Op &op, llvm::SmallVector<std::pair<Value, Value>> &values, | ||
|     DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) { | ||
|   Operation *exitOp = op.getOperation(); | ||
|   if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) { | ||
|     // For declare enter/exit pairs, verify there is exactly one exit op using | ||
|     // the token. | ||
|     // NOTE(review): only a single exit op is supported here; a token with | ||
|     // several users is rejected. | ||
|     if (!op.getToken().hasOneUse()) { | ||
|       op.emitError("declare enter token must have exactly one use"); | ||
|       // Bail out: with zero users the user range below would be empty and | ||
|       // dereferencing its begin iterator is invalid. | ||
|       return; | ||
|     } | ||
|     Operation *user = *op.getToken().getUsers().begin(); | ||
|     auto declareExit = dyn_cast<acc::DeclareExitOp>(user); | ||
|     if (!declareExit) { | ||
|       op.emitError("declare enter token must be used by declare exit op"); | ||
|       // Bail out: a null exit op must not reach the post-dominance query. | ||
|       return; | ||
|     } | ||
|     exitOp = declareExit; | ||
|   } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) { | ||
|     // For enter/exit data pairs, find the corresponding exit_data op. | ||
|     Operation *nextOp = op.getOperation()->getNextNode(); | ||
|     while (nextOp && !isa<acc::ExitDataOp>(nextOp)) | ||
|       nextOp = nextOp->getNextNode(); | ||
|     if (!nextOp) { | ||
|       op.emitError("enter data must have a corresponding exit data op"); | ||
|       // Bail out: a null exit op must not reach the post-dominance query. | ||
|       return; | ||
|     } | ||
|     exitOp = nextOp; | ||
|   } | ||
| | ||
|   for (const auto &[origVal, replacementVal] : values) { | ||
|     // Early-inc iteration: use.set() below unlinks the use from origVal's | ||
|     // use list while that list is being walked. | ||
|     for (auto &use : llvm::make_early_inc_range(origVal.getUses())) { | ||
|       Operation *owner = use.getOwner(); | ||
|       // Only rewrite uses that lie between the entry and the exit: the acc | ||
|       // entry operation must dominate the use and the acc exit operation | ||
|       // must post-dominate it. | ||
|       if (!domInfo.dominates(op.getOperation(), owner) || | ||
|           !postDomInfo.postDominates(exitOp, owner)) | ||
|         continue; | ||
|       if (insideAccComputeRegion(owner)) | ||
|         use.set(replacementVal); | ||
|     } | ||
|   } | ||
| } | ||
|
|
||
| template <typename Op> | ||
| static void | ||
| collectAndReplaceInRegion(Op &op, bool hostToDevice, | ||
| DominanceInfo *domInfo = nullptr, | ||
| PostDominanceInfo *postDomInfo = nullptr) { | ||
| llvm::SmallVector<std::pair<Value, Value>> values; | ||
|
|
||
| if constexpr (std::is_same_v<Op, acc::LoopOp>) { | ||
|
|
@@ -82,16 +130,27 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { | |
| if constexpr (!std::is_same_v<Op, acc::KernelsOp> && | ||
| !std::is_same_v<Op, acc::DataOp> && | ||
| !std::is_same_v<Op, acc::DeclareOp> && | ||
| !std::is_same_v<Op, acc::HostDataOp>) { | ||
| !std::is_same_v<Op, acc::HostDataOp> && | ||
| !std::is_same_v<Op, acc::DeclareEnterOp> && | ||
| !std::is_same_v<Op, acc::EnterDataOp>) { | ||
| collectVars(op.getReductionOperands(), values, hostToDevice); | ||
| collectVars(op.getPrivateOperands(), values, hostToDevice); | ||
| collectVars(op.getFirstprivateOperands(), values, hostToDevice); | ||
| } | ||
| } | ||
|
|
||
| for (auto p : values) | ||
| replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p), | ||
| op.getRegion()); | ||
| if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> || | ||
| std::is_same_v<Op, acc::EnterDataOp>) { | ||
| assert(domInfo && postDomInfo && | ||
| "Dominance info required for DeclareEnterOp"); | ||
| replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo, | ||
| *postDomInfo); | ||
| } else { | ||
| for (auto p : values) { | ||
| replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p), | ||
| op.getRegion()); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| class LegalizeDataValuesInRegion | ||
|
|
@@ -105,10 +164,16 @@ class LegalizeDataValuesInRegion | |
| func::FuncOp funcOp = getOperation(); | ||
| bool replaceHostVsDevice = this->hostToDevice.getValue(); | ||
|
|
||
| // Initialize dominance info | ||
| DominanceInfo domInfo; | ||
| PostDominanceInfo postDomInfo; | ||
| bool computedDomInfo = false; | ||
|
|
||
| funcOp.walk([&](Operation *op) { | ||
| if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) && | ||
| !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) && | ||
| applyToAccDataConstruct)) | ||
| applyToAccDataConstruct) && | ||
| !isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op)) | ||
| return; | ||
|
|
||
| if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { | ||
|
|
@@ -125,6 +190,22 @@ class LegalizeDataValuesInRegion | |
| collectAndReplaceInRegion(declareOp, replaceHostVsDevice); | ||
| } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) { | ||
| collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice); | ||
| } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) { | ||
| if (!computedDomInfo) { | ||
| domInfo = DominanceInfo(funcOp); | ||
| postDomInfo = PostDominanceInfo(funcOp); | ||
| computedDomInfo = true; | ||
| } | ||
| collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo, | ||
| &postDomInfo); | ||
| } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) { | ||
| if (!computedDomInfo) { | ||
| domInfo = DominanceInfo(funcOp); | ||
| postDomInfo = PostDominanceInfo(funcOp); | ||
| computedDomInfo = true; | ||
| } | ||
| collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo, | ||
| &postDomInfo); | ||
| } else { | ||
| llvm_unreachable("unsupported acc region op"); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is too restrictive, since from a language perspective multiple exits are allowed, and a dialect may choose to permit that. I think the algorithm here should collect all of the
declare_exit ops and ensure that at least one of them post-dominates the replaced use.