diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
index a4761f24f16d7..dc39be8574f84 100644
--- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
@@ -11,6 +11,7 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Transforms/CUFCommon.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Runtime/allocatable.h"
 #include "mlir/IR/SymbolTable.h"
@@ -58,6 +59,32 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
       prepareImplicitDeviceGlobals(funcOp, symTable);
       return mlir::WalkResult::advance();
     });
+
+    // Copying the device global variable into the gpu module
+    mlir::SymbolTable parentSymTable(mod);
+    auto gpuMod =
+        parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
+    if (gpuMod) {
+      mlir::SymbolTable gpuSymTable(gpuMod);
+      for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
+        auto attr = globalOp.getDataAttrAttr();
+        if (!attr)
+          continue;
+        switch (attr.getValue()) {
+        case cuf::DataAttribute::Device:
+        case cuf::DataAttribute::Constant:
+        case cuf::DataAttribute::Managed: {
+          auto globalName{globalOp.getSymbol().getValue()};
+          if (gpuSymTable.lookup(globalName)) {
+            break;
+          }
+          gpuSymTable.insert(globalOp->clone());
+        } break;
+        default:
+          break;
+        }
+      }
+    }
   }
 };
 } // namespace
diff --git a/flang/test/Fir/CUDA/cuda-device-global.f90 b/flang/test/Fir/CUDA/cuda-device-global.f90
new file mode 100644
index 0000000000000..c83a938d5af21
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-device-global.f90
@@ -0,0 +1,13 @@
+
+// RUN: fir-opt --split-input-file --cuf-device-global %s | FileCheck %s
+
+
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module} {
+  fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
+
+  gpu.module @cuda_device_mod [#nvvm.target] {
+  }
+}
+
+// CHECK: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK-NEXT: fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>