From 5580f49b455b5edad144620fc319086d1ff6929c Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Mon, 30 Jun 2025 13:28:02 -0700 Subject: [PATCH] [flang][cuda] Bring PARAMETER arrays into the GPU module --- flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp | 10 +++++++++- flang/test/Fir/CUDA/cuda-device-global.f90 | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp index 328e2374115b0..bfb0daeacb8c3 100644 --- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp +++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp @@ -113,8 +113,16 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> { return signalPassFailure(); mlir::SymbolTable gpuSymTable(gpuMod); for (auto globalOp : mod.getOps<fir::GlobalOp>()) { - if (cuf::isRegisteredDeviceGlobal(globalOp)) + if (cuf::isRegisteredDeviceGlobal(globalOp)) { candidates.insert(globalOp); + } else if (globalOp.getConstant() && + mlir::isa<fir::SequenceType>( + fir::unwrapRefType(globalOp.resultType()))) { + mlir::Attribute initAttr = + globalOp.getInitVal().value_or(mlir::Attribute()); + if (initAttr && mlir::dyn_cast<mlir::DenseElementsAttr>(initAttr)) + candidates.insert(globalOp); + } } for (auto globalOp : candidates) { auto globalName{globalOp.getSymbol().getValue()}; diff --git a/flang/test/Fir/CUDA/cuda-device-global.f90 b/flang/test/Fir/CUDA/cuda-device-global.f90 index 8cac643b27c34..4c634513745fd 100644 --- a/flang/test/Fir/CUDA/cuda-device-global.f90 +++ b/flang/test/Fir/CUDA/cuda-device-global.f90 @@ -11,3 +11,16 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.conta // CHECK: gpu.module @cuda_device_mo // CHECK-NEXT: fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32> + +// ----- + +module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module} { + fir.global @_QMm1ECb(dense<[90, 100, 110]> : tensor<3xi32>) 
constant : !fir.array<3xi32> + fir.global @_QMm2ECc(dense<[100, 200, 300]> : tensor<3xi32>) constant : !fir.array<3xi32> +} + +// CHECK: fir.global @_QMm1ECb +// CHECK: fir.global @_QMm2ECc +// CHECK: gpu.module @cuda_device_mod +// CHECK-DAG: fir.global @_QMm2ECc +// CHECK-DAG: fir.global @_QMm1ECb