1212#include " flang/Optimizer/HLFIR/HLFIROps.h"
1313#include " mlir/Dialect/Func/IR/FuncOps.h"
1414#include " mlir/Dialect/LLVMIR/NVVMDialect.h"
15+ #include " mlir/Dialect/OpenACC/OpenACC.h"
1516
1617// / Retrieve or create the CUDA Fortran GPU module in the give in \p mod.
1718mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule (mlir::ModuleOp mod,
@@ -31,32 +32,47 @@ mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
3132 return gpuMod;
3233}
3334
34- bool cuf::isInCUDADeviceContext (mlir::Operation *op) {
35- if (!op)
35+ bool cuf::isCUDADeviceContext (mlir::Operation *op) {
36+ if (!op || !op-> getParentRegion () )
3637 return false ;
37- if (op->getParentOfType <cuf::KernelOp>() ||
38- op->getParentOfType <mlir::gpu::GPUFuncOp>())
38+ return isCUDADeviceContext (*op->getParentRegion ());
39+ }
40+
41+ // Check if the insertion point is currently in a device context. HostDevice
42+ // subprogram are not considered fully device context so it will return false
43+ // for it.
44+ // If the insertion point is inside an OpenACC region op, it is considered
45+ // device context.
46+ bool cuf::isCUDADeviceContext (mlir::Region ®ion) {
47+ if (region.getParentOfType <cuf::KernelOp>())
48+ return true ;
49+ if (region.getParentOfType <mlir::acc::ComputeRegionOpInterface>())
3950 return true ;
40- if (auto funcOp = op->getParentOfType <mlir::func::FuncOp>()) {
41- if (auto cudaProcAttr = funcOp->getAttrOfType <cuf::ProcAttributeAttr>(
42- cuf::getProcAttrName ())) {
43- return cudaProcAttr.getValue () != cuf::ProcAttribute::Host;
51+ if (auto funcOp = region.getParentOfType <mlir::func::FuncOp>()) {
52+ if (auto cudaProcAttr =
53+ funcOp.getOperation ()->getAttrOfType <cuf::ProcAttributeAttr>(
54+ cuf::getProcAttrName ())) {
55+ return cudaProcAttr.getValue () != cuf::ProcAttribute::Host &&
56+ cudaProcAttr.getValue () != cuf::ProcAttribute::HostDevice;
4457 }
4558 }
4659 return false ;
4760}
4861
49- bool cuf::isRegisteredDeviceGlobal (fir::GlobalOp op) {
50- if (op.getConstant ())
51- return false ;
52- auto attr = op.getDataAttr ();
62+ bool cuf::isRegisteredDeviceAttr (std::optional<cuf::DataAttribute> attr) {
5363 if (attr && (*attr == cuf::DataAttribute::Device ||
5464 *attr == cuf::DataAttribute::Managed ||
5565 *attr == cuf::DataAttribute::Constant))
5666 return true ;
5767 return false ;
5868}
5969
70+ bool cuf::isRegisteredDeviceGlobal (fir::GlobalOp op) {
71+ if (op.getConstant ())
72+ return false ;
73+ return isRegisteredDeviceAttr (op.getDataAttr ());
74+ }
75+
6076void cuf::genPointerSync (const mlir::Value box, fir::FirOpBuilder &builder) {
6177 if (auto declareOp = box.getDefiningOp <hlfir::DeclareOp>()) {
6278 if (auto addrOfOp = declareOp.getMemref ().getDefiningOp <fir::AddrOfOp>()) {
0 commit comments