Skip to content

Commit ea04148

Browse files
authored
[flang][cuda] Extend implicit global handling to any type descriptor (#119769)
Relax the check to also handle other type descriptor globals.
1 parent 80cd9e4 commit ea04148

File tree

2 files changed

+83
-83
lines changed

2 files changed

+83
-83
lines changed

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -32,18 +32,17 @@ static constexpr llvm::StringRef builtinPrefix = "_QM__fortran_builtins";
3232

3333
static void processAddrOfOp(fir::AddrOfOp addrOfOp,
3434
mlir::SymbolTable &symbolTable,
35-
llvm::DenseSet<fir::GlobalOp> &candidates) {
35+
llvm::DenseSet<fir::GlobalOp> &candidates,
36+
bool recurseInGlobal) {
3637
if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
3738
addrOfOp.getSymbol().getRootReference().getValue())) {
3839
// TO DO: limit candidates to non-scalars. Scalars appear to have been
3940
// folded in already.
4041
if (globalOp.getConstant()) {
41-
// Limit recursion to builtin global for now.
42-
if (globalOp.getSymName().starts_with(builtinPrefix)) {
42+
if (recurseInGlobal)
4343
globalOp.walk([&](fir::AddrOfOp op) {
44-
processAddrOfOp(op, symbolTable, candidates);
44+
processAddrOfOp(op, symbolTable, candidates, recurseInGlobal);
4545
});
46-
}
4746
candidates.insert(globalOp);
4847
}
4948
}
@@ -52,18 +51,18 @@ static void processAddrOfOp(fir::AddrOfOp addrOfOp,
5251
static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
5352
llvm::DenseSet<fir::GlobalOp> &candidates) {
5453
if (auto recTy = mlir::dyn_cast<fir::RecordType>(
55-
fir::unwrapRefType(emboxOp.getMemref().getType())))
56-
// Only look at builtin record type.
57-
if (recTy.getName().starts_with(builtinPrefix))
58-
if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
59-
fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
60-
if (!candidates.contains(globalOp)) {
61-
globalOp.walk([&](fir::AddrOfOp op) {
62-
processAddrOfOp(op, symbolTable, candidates);
63-
});
64-
candidates.insert(globalOp);
65-
}
54+
fir::unwrapRefType(emboxOp.getMemref().getType()))) {
55+
if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
56+
fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
57+
if (!candidates.contains(globalOp)) {
58+
globalOp.walk([&](fir::AddrOfOp op) {
59+
processAddrOfOp(op, symbolTable, candidates,
60+
/*recurseInGlobal=*/true);
61+
});
62+
candidates.insert(globalOp);
6663
}
64+
}
65+
}
6766
}
6867

6968
static void
@@ -74,7 +73,7 @@ prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
7473
funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
7574
if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
7675
funcOp.walk([&](fir::AddrOfOp op) {
77-
processAddrOfOp(op, symbolTable, candidates);
76+
processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false);
7877
});
7978
funcOp.walk(
8079
[&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); });
@@ -97,7 +96,8 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
9796
});
9897
mod.walk([&](cuf::KernelOp kernelOp) {
9998
kernelOp.walk([&](fir::AddrOfOp addrOfOp) {
100-
processAddrOfOp(addrOfOp, symTable, candidates);
99+
processAddrOfOp(addrOfOp, symTable, candidates,
100+
/*recurseInGlobal=*/false);
101101
});
102102
});
103103

0 commit comments

Comments
 (0)