Skip to content

Commit 067ce5c

Browse files
authored
[flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal pass (llvm#114468)
Make the pass functional if gpu module was not created yet.
1 parent cf0b6cc commit 067ce5c

File tree

3 files changed

+26
-20
lines changed

3 files changed

+26
-20
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def CUFDeviceGlobal :
432432
Pass<"cuf-device-global", "mlir::ModuleOp"> {
433433
let summary = "Flag globals used in device function with data attribute";
434434
let dependentDialects = [
435-
"cuf::CUFDialect"
435+
"cuf::CUFDialect", "mlir::gpu::GPUDialect", "mlir::NVVM::NVVMDialect"
436436
];
437437
}
438438

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "flang/Optimizer/Transforms/CUFCommon.h"
1515
#include "flang/Runtime/CUDA/common.h"
1616
#include "flang/Runtime/allocatable.h"
17+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1718
#include "mlir/IR/SymbolTable.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
@@ -62,27 +63,26 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
6263

6364
// Copying the device global variable into the gpu module
6465
mlir::SymbolTable parentSymTable(mod);
65-
auto gpuMod =
66-
parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
67-
if (gpuMod) {
68-
mlir::SymbolTable gpuSymTable(gpuMod);
69-
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
70-
auto attr = globalOp.getDataAttrAttr();
71-
if (!attr)
72-
continue;
73-
switch (attr.getValue()) {
74-
case cuf::DataAttribute::Device:
75-
case cuf::DataAttribute::Constant:
76-
case cuf::DataAttribute::Managed: {
77-
auto globalName{globalOp.getSymbol().getValue()};
78-
if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
79-
break;
80-
}
81-
gpuSymTable.insert(globalOp->clone());
82-
} break;
83-
default:
66+
auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
67+
if (!gpuMod)
68+
return signalPassFailure();
69+
mlir::SymbolTable gpuSymTable(gpuMod);
70+
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
71+
auto attr = globalOp.getDataAttrAttr();
72+
if (!attr)
73+
continue;
74+
switch (attr.getValue()) {
75+
case cuf::DataAttribute::Device:
76+
case cuf::DataAttribute::Constant:
77+
case cuf::DataAttribute::Managed: {
78+
auto globalName{globalOp.getSymbol().getValue()};
79+
if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
8480
break;
8581
}
82+
gpuSymTable.insert(globalOp->clone());
83+
} break;
84+
default:
85+
break;
8686
}
8787
}
8888
}

flang/test/Fir/CUDA/cuda-implicit-device-global.f90

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ // Test that global used in device function are flagged with the correct
2525
// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
2626
// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
2727

28+
// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
29+
// CHECK: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a
30+
2831
// -----
2932

3033
func.func @_QMdataPsetvalue() {
@@ -47,3 +50,6 @@ // Test that global used in device function are flagged with the correct
4750
// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
4851
// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
4952
// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>
53+
54+
// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
55+
// CHECK-NOT: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a

0 commit comments

Comments
 (0)