Skip to content

Commit 51751bf

Browse files
clementvalgithub-actions[bot]
authored andcommitted
Automerge: [mlir][acc] Add isValidValueUse to OpenACCSupport (#171538)
Add a new API `isValidValueUse ` to OpenACCSupport. This is used in ACCImplicitData to check value that are already legal in the OpenACC region and do not require implicit clause to be generated. An example would be a CUDA Fortran device variable that is already on the GPU.
2 parents a4a237b + bf81bde commit 51751bf

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ struct OpenACCSupportTraits {
8585
/// Check if a symbol use is valid for use in an OpenACC region.
8686
virtual bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
8787
Operation **definingOpPtr) = 0;
88+
89+
/// Check if a value use is legal in an OpenACC region.
90+
virtual bool isValidValueUse(Value v, mlir::Region &region) = 0;
8891
};
8992

9093
/// SFINAE helpers to detect if implementation has optional methods
@@ -97,6 +100,14 @@ struct OpenACCSupportTraits {
97100
llvm::is_detected<isValidSymbolUse_t, ImplT, Operation *, SymbolRefAttr,
98101
Operation **>;
99102

103+
template <typename ImplT, typename... Args>
104+
using isValidValueUse_t =
105+
decltype(std::declval<ImplT>().isValidValueUse(std::declval<Args>()...));
106+
107+
template <typename ImplT>
108+
using has_isValidValueUse =
109+
llvm::is_detected<isValidValueUse_t, ImplT, Value, Region &>;
110+
100111
/// This class wraps a concrete OpenACCSupport implementation and forwards
101112
/// interface calls to it. This provides type erasure, allowing different
102113
/// implementation types to be used interchangeably without inheritance.
@@ -128,6 +139,13 @@ struct OpenACCSupportTraits {
128139
return acc::isValidSymbolUse(user, symbol, definingOpPtr);
129140
}
130141

142+
bool isValidValueUse(Value v, Region &region) final {
143+
if constexpr (has_isValidSymbolUse<ImplT>::value)
144+
return impl.isValidValueUse(v, region);
145+
else
146+
return false;
147+
}
148+
131149
private:
132150
ImplT impl;
133151
};
@@ -189,6 +207,12 @@ class OpenACCSupport {
189207
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
190208
Operation **definingOpPtr = nullptr);
191209

210+
/// Check if a value use is legal in an OpenACC region.
211+
///
212+
/// \param v The MLIR value to check for legality.
213+
/// \param region The MLIR region in which the legality is checked.
214+
bool isValidValueUse(Value v, Region &region);
215+
192216
/// Signal that this analysis should always be preserved so that
193217
/// underlying implementation registration is not lost.
194218
bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {

mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,11 @@ bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
4848
return acc::isValidSymbolUse(user, symbol, definingOpPtr);
4949
}
5050

51+
bool OpenACCSupport::isValidValueUse(Value v, Region &region) {
52+
if (impl)
53+
return impl->isValidValueUse(v, region);
54+
return false;
55+
}
56+
5157
} // namespace acc
5258
} // namespace mlir

mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,10 @@ class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
254254

255255
/// Generates the implicit data ops for a compute construct.
256256
template <typename OpT>
257-
void generateImplicitDataOps(
258-
ModuleOp &module, OpT computeConstructOp,
259-
std::optional<acc::ClauseDefaultValue> &defaultClause);
257+
void
258+
generateImplicitDataOps(ModuleOp &module, OpT computeConstructOp,
259+
std::optional<acc::ClauseDefaultValue> &defaultClause,
260+
acc::OpenACCSupport &accSupport);
260261

261262
/// Generates a private recipe for a variable.
262263
acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module, Value var,
@@ -277,12 +278,16 @@ class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
277278

278279
/// Determines if a variable is a candidate for implicit data mapping.
279280
/// Returns true if the variable is a candidate, false otherwise.
280-
static bool isCandidateForImplicitData(Value val, Region &accRegion) {
281+
static bool isCandidateForImplicitData(Value val, Region &accRegion,
282+
acc::OpenACCSupport &accSupport) {
281283
// Ensure the variable is an allowed type for data clause.
282284
if (!acc::isPointerLikeType(val.getType()) &&
283285
!acc::isMappableType(val.getType()))
284286
return false;
285287

288+
if (accSupport.isValidValueUse(val, accRegion))
289+
return false;
290+
286291
// If this is already coming from a data clause, we do not need to generate
287292
// another.
288293
if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp()))
@@ -683,7 +688,8 @@ static void insertInSortedOrder(SmallVector<Value> &sortedDataClauseOperands,
683688
template <typename OpT>
684689
void ACCImplicitData::generateImplicitDataOps(
685690
ModuleOp &module, OpT computeConstructOp,
686-
std::optional<acc::ClauseDefaultValue> &defaultClause) {
691+
std::optional<acc::ClauseDefaultValue> &defaultClause,
692+
acc::OpenACCSupport &accSupport) {
687693
// Implicit data attributes are only applied if "[t]here is no default(none)
688694
// clause visible at the compute construct."
689695
if (defaultClause.has_value() &&
@@ -699,7 +705,7 @@ void ACCImplicitData::generateImplicitDataOps(
699705

700706
// 2) Run the filtering to find relevant pointers that need copied.
701707
auto isCandidate{[&](Value val) -> bool {
702-
return isCandidateForImplicitData(val, accRegion);
708+
return isCandidateForImplicitData(val, accRegion, accSupport);
703709
}};
704710
auto candidateVars(
705711
llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
@@ -763,6 +769,9 @@ void ACCImplicitData::generateImplicitDataOps(
763769

764770
void ACCImplicitData::runOnOperation() {
765771
ModuleOp module = this->getOperation();
772+
773+
acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
774+
766775
module.walk([&](Operation *op) {
767776
if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
768777
assert(op->getNumRegions() == 1 && "must have 1 region");
@@ -771,7 +780,7 @@ void ACCImplicitData::runOnOperation() {
771780
llvm::TypeSwitch<Operation *, void>(op)
772781
.Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
773782
[&](auto op) {
774-
generateImplicitDataOps(module, op, defaultClause);
783+
generateImplicitDataOps(module, op, defaultClause, accSupport);
775784
})
776785
.Default([&](Operation *) {});
777786
}

0 commit comments

Comments
 (0)