Skip to content

Commit 07e2623

Browse files
committed
[flang][cuda] Copying device globals in the gpu module
1 parent 7131569 commit 07e2623

File tree

2 files changed: 40 additions and 0 deletions

2 files changed: 40 additions and 0 deletions

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Optimizer/Dialect/FIRDialect.h"
1212
#include "flang/Optimizer/Dialect/FIROps.h"
1313
#include "flang/Optimizer/HLFIR/HLFIROps.h"
14+
#include "flang/Optimizer/Transforms/CUFCommon.h"
1415
#include "flang/Runtime/CUDA/common.h"
1516
#include "flang/Runtime/allocatable.h"
1617
#include "mlir/IR/SymbolTable.h"
@@ -58,6 +59,31 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
5859
prepareImplicitDeviceGlobals(funcOp, symTable);
5960
return mlir::WalkResult::advance();
6061
});
62+
63+
mlir::SymbolTable parentSymTable(mod);
64+
auto gpuMod =
65+
parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
66+
if (gpuMod) {
67+
mlir::SymbolTable gpuSymTable(gpuMod);
68+
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
69+
auto attr = globalOp.getDataAttrAttr();
70+
if (!attr)
71+
continue;
72+
switch (attr.getValue()) {
73+
case cuf::DataAttribute::Device:
74+
case cuf::DataAttribute::Constant:
75+
case cuf::DataAttribute::Managed: {
76+
auto globalName{globalOp.getSymbol().getValue()};
77+
if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
78+
break;
79+
}
80+
gpuSymTable.insert(globalOp->clone());
81+
} break;
82+
default:
83+
break;
84+
}
85+
}
86+
}
6187
}
6288
};
6389
} // namespace
Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
2+
// RUN: fir-opt --split-input-file --cuf-device-global %s | FileCheck %s
3+
4+
5+
// -----// IR Dump After CUFLaunchToGPU (cuf-fir-launch-to-gpu) //----- //
6+
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module} {
7+
fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
8+
9+
gpu.module @cuda_device_mod [#nvvm.target] {
10+
}
11+
}
12+
13+
// CHECK: gpu.module @cuda_device_mod [#nvvm.target]
14+
// CHECK-NEXT: fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>

0 commit comments

Comments
 (0)