Skip to content

Commit 89bb99d

Browse files
[acc][flang] Implement acc interface for tracking type descriptors (#168982)
FIR operations that use derived types need to have type descriptor globals available on device when offloading. Examples of this can be seen in `CUFDeviceGlobal` which ensures that such type descriptor uses work on device for CUF. Similarly, this is needed for OpenACC. This change introduces a new interface to the OpenACC dialect named `IndirectGlobalAccessOpInterface` which can be attached to operations that may result in generation of accesses that use type descriptor globals. This functionality is needed for the `ACCImplicitDeclare` pass that is coming in a follow-up change which implicitly ensures that all referenced globals are available in OpenACC compute contexts. The interface provides a `getReferencedSymbols` method that collects all global symbols referenced by an operation. When a symbol table is provided, the implementation for FIR recursively walks type descriptor globals to find all transitively referenced symbols. Note that alternately this could have been implemented in different ways: - Codegen could implicitly generate such type globals as needed by changing the technique that relies on populating them during lowering (eg generate them directly in gpu.module during codegen). - This interface could attach to types instead of operations for a potentially more conservative implementation which maps all type descriptors even if the underlying implementation using it won't necessarily need such mapping. The technique chosen here is consistent with `CUFDeviceGlobal` (which walks operations inside `prepareImplicitDeviceGlobals`) and avoids conservative mapping of all type descriptors.
1 parent 0b6db77 commit 89bb99d

File tree

4 files changed

+137
-0
lines changed

4 files changed

+137
-0
lines changed

flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ struct GlobalVariableModel
6767
bool isConstant(mlir::Operation *op) const;
6868
};
6969

70+
template <typename Op>
71+
struct IndirectGlobalAccessModel
72+
: public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
73+
IndirectGlobalAccessModel<Op>, Op> {
74+
void getReferencedSymbols(mlir::Operation *op,
75+
llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
76+
mlir::SymbolTable *symbolTable) const;
77+
};
78+
7079
} // namespace fir::acc
7180

7281
#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_

flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#include "flang/Optimizer/Dialect/FIROps.h"
1616
#include "flang/Optimizer/HLFIR/HLFIROps.h"
17+
#include "flang/Optimizer/Support/InternalNames.h"
18+
#include "mlir/IR/SymbolTable.h"
19+
#include "llvm/ADT/SmallSet.h"
1720

1821
namespace fir::acc {
1922

@@ -68,4 +71,97 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
6871
return globalOp.getConstant().has_value();
6972
}
7073

74+
// Helper to recursively process address-of operations in derived type
75+
// descriptors and collect all needed fir.globals.
76+
static void processAddrOfOpInDerivedTypeDescriptor(
77+
fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab,
78+
llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
79+
llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
80+
if (auto globalOp = symTab.lookup<fir::GlobalOp>(
81+
addrOfOp.getSymbol().getLeafReference().getValue())) {
82+
if (globalsSet.contains(globalOp))
83+
return;
84+
globalsSet.insert(globalOp);
85+
symbols.push_back(addrOfOp.getSymbolAttr());
86+
globalOp.walk([&](fir::AddrOfOp op) {
87+
processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols);
88+
});
89+
}
90+
}
91+
92+
// Utility to collect referenced symbols for type descriptors of derived types.
93+
// This is the common logic for operations that may require type descriptor
94+
// globals.
95+
static void collectReferencedSymbolsForType(
96+
mlir::Type ty, mlir::Operation *op,
97+
llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
98+
mlir::SymbolTable *symbolTable) {
99+
ty = fir::getDerivedType(fir::unwrapRefType(ty));
100+
101+
// Look for type descriptor globals only if it's a derived (record) type
102+
if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
103+
// If no symbol table provided, simply add the type descriptor name
104+
if (!symbolTable) {
105+
symbols.push_back(mlir::SymbolRefAttr::get(
106+
op->getContext(),
107+
fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
108+
return;
109+
}
110+
111+
// Otherwise, do full lookup and recursive processing
112+
llvm::SmallSet<mlir::Operation *, 16> globalsSet;
113+
114+
fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
115+
fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
116+
if (!globalOp)
117+
globalOp = symbolTable->lookup<fir::GlobalOp>(
118+
fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));
119+
120+
if (globalOp) {
121+
globalsSet.insert(globalOp);
122+
symbols.push_back(
123+
mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName()));
124+
globalOp.walk([&](fir::AddrOfOp addrOp) {
125+
processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet,
126+
symbols);
127+
});
128+
}
129+
}
130+
}
131+
132+
template <>
133+
void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
134+
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
135+
mlir::SymbolTable *symbolTable) const {
136+
auto allocaOp = mlir::cast<fir::AllocaOp>(op);
137+
collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
138+
}
139+
140+
template <>
141+
void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
142+
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
143+
mlir::SymbolTable *symbolTable) const {
144+
auto emboxOp = mlir::cast<fir::EmboxOp>(op);
145+
collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
146+
symbolTable);
147+
}
148+
149+
template <>
150+
void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
151+
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
152+
mlir::SymbolTable *symbolTable) const {
153+
auto reboxOp = mlir::cast<fir::ReboxOp>(op);
154+
collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
155+
symbolTable);
156+
}
157+
158+
template <>
159+
void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
160+
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
161+
mlir::SymbolTable *symbolTable) const {
162+
auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
163+
collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
164+
symbolTable);
165+
}
166+
71167
} // namespace fir::acc

flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
5252

5353
fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
5454
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
55+
56+
fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
57+
*ctx);
58+
fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
59+
*ctx);
60+
fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
61+
*ctx);
62+
fir::TypeDescOp::attachInterface<
63+
IndirectGlobalAccessModel<fir::TypeDescOp>>(*ctx);
5564
});
5665

5766
// Register HLFIR operation interfaces

mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
7575
];
7676
}
7777

78+
def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> {
79+
let cppNamespace = "::mlir::acc";
80+
81+
let description = [{
82+
An interface for operations that indirectly access global symbols.
83+
This interface provides a way to query which global symbols are referenced
84+
by an operation, which is useful for tracking dependencies and performing
85+
analysis on global variable usage.
86+
87+
The symbolTable parameter is optional. If null, implementations will look up
88+
their own symbol table. This allows callers to pass a pre-existing symbol
89+
table for efficiency when querying multiple operations.
90+
}];
91+
92+
let methods = [
93+
InterfaceMethod<"Get the symbols referenced by this operation",
94+
"void",
95+
"getReferencedSymbols",
96+
(ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols,
97+
"::mlir::SymbolTable *":$symbolTable)>,
98+
];
99+
}
100+
78101
#endif // OPENACC_OPS_INTERFACES

0 commit comments

Comments
 (0)