|
14 | 14 | #include "flang/Optimizer/Transforms/CUFCommon.h" |
15 | 15 | #include "flang/Runtime/CUDA/common.h" |
16 | 16 | #include "flang/Runtime/allocatable.h" |
| 17 | +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
17 | 18 | #include "mlir/IR/SymbolTable.h" |
18 | 19 | #include "mlir/Pass/Pass.h" |
19 | 20 | #include "mlir/Transforms/DialectConversion.h" |
@@ -62,27 +63,26 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> { |
62 | 63 |
|
63 | 64 | // Copying the device global variable into the gpu module |
64 | 65 | mlir::SymbolTable parentSymTable(mod); |
65 | | - auto gpuMod = |
66 | | - parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName); |
67 | | - if (gpuMod) { |
68 | | - mlir::SymbolTable gpuSymTable(gpuMod); |
69 | | - for (auto globalOp : mod.getOps<fir::GlobalOp>()) { |
70 | | - auto attr = globalOp.getDataAttrAttr(); |
71 | | - if (!attr) |
72 | | - continue; |
73 | | - switch (attr.getValue()) { |
74 | | - case cuf::DataAttribute::Device: |
75 | | - case cuf::DataAttribute::Constant: |
76 | | - case cuf::DataAttribute::Managed: { |
77 | | - auto globalName{globalOp.getSymbol().getValue()}; |
78 | | - if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) { |
79 | | - break; |
80 | | - } |
81 | | - gpuSymTable.insert(globalOp->clone()); |
82 | | - } break; |
83 | | - default: |
| 66 | + auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable); |
| 67 | + if (!gpuMod) |
| 68 | + return signalPassFailure(); |
| 69 | + mlir::SymbolTable gpuSymTable(gpuMod); |
| 70 | + for (auto globalOp : mod.getOps<fir::GlobalOp>()) { |
| 71 | + auto attr = globalOp.getDataAttrAttr(); |
| 72 | + if (!attr) |
| 73 | + continue; |
| 74 | + switch (attr.getValue()) { |
| 75 | + case cuf::DataAttribute::Device: |
| 76 | + case cuf::DataAttribute::Constant: |
| 77 | + case cuf::DataAttribute::Managed: { |
| 78 | + auto globalName{globalOp.getSymbol().getValue()}; |
| 79 | + if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) { |
84 | 80 | break; |
85 | 81 | } |
| 82 | + gpuSymTable.insert(globalOp->clone()); |
| 83 | + } break; |
| 84 | + default: |
| 85 | + break; |
86 | 86 | } |
87 | 87 | } |
88 | 88 | } |
|
0 commit comments