1818#include " mlir/IR/SymbolTable.h"
1919#include " mlir/Pass/Pass.h"
2020#include " mlir/Transforms/DialectConversion.h"
21+ #include " llvm/ADT/DenseSet.h"
2122
2223namespace fir {
2324#define GEN_PASS_DEF_CUFDEVICEGLOBAL
@@ -27,36 +28,30 @@ namespace fir {
2728namespace {
2829
2930static void processAddrOfOp (fir::AddrOfOp addrOfOp,
30- mlir::SymbolTable &symbolTable, bool onlyConstant) {
31+ mlir::SymbolTable &symbolTable,
32+ llvm::DenseSet<fir::GlobalOp> &candidates) {
3133 if (auto globalOp = symbolTable.lookup <fir::GlobalOp>(
3234 addrOfOp.getSymbol ().getRootReference ().getValue ())) {
33- bool isCandidate{(onlyConstant ? globalOp.getConstant () : true ) &&
34- !globalOp.getDataAttr ()};
35- if (isCandidate)
36- globalOp.setDataAttrAttr (cuf::DataAttributeAttr::get (
37- addrOfOp.getContext (), globalOp.getConstant ()
38- ? cuf::DataAttribute::Constant
39- : cuf::DataAttribute::Device));
35+ // TO DO: limit candidates to non-scalars. Scalars appear to have been
36+ // folded in already.
37+ if (globalOp.getConstant ()) {
38+ candidates.insert (globalOp);
39+ }
4040 }
4141}
4242
43- static void prepareImplicitDeviceGlobals (mlir::func::FuncOp funcOp,
44- mlir::SymbolTable &symbolTable,
45- bool onlyConstant = true ) {
43+ static void
44+ prepareImplicitDeviceGlobals (mlir::func::FuncOp funcOp,
45+ mlir::SymbolTable &symbolTable,
46+ llvm::DenseSet<fir::GlobalOp> &candidates) {
47+
4648 auto cudaProcAttr{
4749 funcOp->getAttrOfType <cuf::ProcAttributeAttr>(cuf::getProcAttrName ())};
48- if (!cudaProcAttr || cudaProcAttr.getValue () == cuf::ProcAttribute::Host) {
49- // Look for globlas in CUF KERNEL DO operations.
50- for (auto cufKernelOp : funcOp.getBody ().getOps <cuf::KernelOp>()) {
51- cufKernelOp.walk ([&](fir::AddrOfOp addrOfOp) {
52- processAddrOfOp (addrOfOp, symbolTable, onlyConstant);
53- });
54- }
55- return ;
50+ if (cudaProcAttr && cudaProcAttr.getValue () != cuf::ProcAttribute::Host) {
51+ funcOp.walk ([&](fir::AddrOfOp addrOfOp) {
52+ processAddrOfOp (addrOfOp, symbolTable, candidates);
53+ });
5654 }
57- funcOp.walk ([&](fir::AddrOfOp addrOfOp) {
58- processAddrOfOp (addrOfOp, symbolTable, onlyConstant);
59- });
6055}
6156
6257class CUFDeviceGlobal : public fir ::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
@@ -67,9 +62,10 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
6762 if (!mod)
6863 return signalPassFailure ();
6964
65+ llvm::DenseSet<fir::GlobalOp> candidates;
7066 mlir::SymbolTable symTable (mod);
7167 mod.walk ([&](mlir::func::FuncOp funcOp) {
72- prepareImplicitDeviceGlobals (funcOp, symTable);
68+ prepareImplicitDeviceGlobals (funcOp, symTable, candidates );
7369 return mlir::WalkResult::advance ();
7470 });
7571
@@ -80,22 +76,15 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
8076 return signalPassFailure ();
8177 mlir::SymbolTable gpuSymTable (gpuMod);
8278 for (auto globalOp : mod.getOps <fir::GlobalOp>()) {
83- auto attr = globalOp.getDataAttrAttr ();
84- if (!attr)
85- continue ;
86- switch (attr.getValue ()) {
87- case cuf::DataAttribute::Device:
88- case cuf::DataAttribute::Constant:
89- case cuf::DataAttribute::Managed: {
90- auto globalName{globalOp.getSymbol ().getValue ()};
91- if (gpuSymTable.lookup <fir::GlobalOp>(globalName)) {
92- break ;
93- }
94- gpuSymTable.insert (globalOp->clone ());
95- } break ;
96- default :
79+ if (cuf::isRegisteredDeviceGlobal (globalOp))
80+ candidates.insert (globalOp);
81+ }
82+ for (auto globalOp : candidates) {
83+ auto globalName{globalOp.getSymbol ().getValue ()};
84+ if (gpuSymTable.lookup <fir::GlobalOp>(globalName)) {
9785 break ;
9886 }
87+ gpuSymTable.insert (globalOp->clone ());
9988 }
10089 }
10190};
0 commit comments