diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index a41f0f348f27a..d89713a9fc0b9 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -432,7 +432,7 @@ def CUFDeviceGlobal : Pass<"cuf-device-global", "mlir::ModuleOp"> { let summary = "Flag globals used in device function with data attribute"; let dependentDialects = [ - "cuf::CUFDialect" + "cuf::CUFDialect", "mlir::gpu::GPUDialect", "mlir::NVVM::NVVMDialect" ]; } diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp index dc39be8574f84..a69b47ff74391 100644 --- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp +++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp @@ -14,6 +14,7 @@ #include "flang/Optimizer/Transforms/CUFCommon.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/allocatable.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -62,27 +63,26 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase { // Copying the device global variable into the gpu module mlir::SymbolTable parentSymTable(mod); - auto gpuMod = - parentSymTable.lookup(cudaDeviceModuleName); - if (gpuMod) { - mlir::SymbolTable gpuSymTable(gpuMod); - for (auto globalOp : mod.getOps()) { - auto attr = globalOp.getDataAttrAttr(); - if (!attr) - continue; - switch (attr.getValue()) { - case cuf::DataAttribute::Device: - case cuf::DataAttribute::Constant: - case cuf::DataAttribute::Managed: { - auto globalName{globalOp.getSymbol().getValue()}; - if (gpuSymTable.lookup(globalName)) { - break; - } - gpuSymTable.insert(globalOp->clone()); - } break; - default: + auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable); + if (!gpuMod) + return signalPassFailure(); + mlir::SymbolTable gpuSymTable(gpuMod); + for (auto globalOp : mod.getOps()) { + auto attr = globalOp.getDataAttrAttr(); + if (!attr) + continue; + switch (attr.getValue()) { + case cuf::DataAttribute::Device: + case cuf::DataAttribute::Constant: + case cuf::DataAttribute::Managed: { + auto globalName{globalOp.getSymbol().getValue()}; + if (gpuSymTable.lookup(globalName)) { break; } + gpuSymTable.insert(globalOp->clone()); + } break; + default: + break; } } } diff --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 index 82a0c5948d9cb..18b56a491cd65 100644 --- a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 +++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 @@ -25,6 +25,9 @@ // Test that global used in device function are flagged with the correct // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath : (i32, !fir.ref, i32) -> !fir.ref // CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda} constant : !fir.char<1,32> +// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target] +// CHECK: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a + // ----- func.func @_QMdataPsetvalue() { @@ -47,3 +50,6 @@ // Test that global used in device function are flagged with the correct // CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref>) -> !fir.ref // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath : (i32, !fir.ref, i32) -> !fir.ref // CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32> + +// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target] +// CHECK-NOT: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a