Skip to content

Commit 6a7754f

Browse files
authored
[flang][cuda] Set address space for constant variables (llvm#163430)
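Set the correct address space for CUDA Fortran constant variables. Taking the address of such a variable (`fir.address_of`) inside a GPU module now introduces an address space cast from the global's address space to the program address space.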
1 parent 975fba1 commit 6a7754f

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,17 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
176176
llvm::LogicalResult
177177
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
178178
mlir::ConversionPatternRewriter &rewriter) const override {
179+
180+
if (auto gpuMod = addr->getParentOfType<mlir::gpu::GPUModuleOp>()) {
181+
auto global = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
182+
assert(global && "Expect global in gpu module");
183+
replaceWithAddrOfOrASCast(rewriter, addr->getLoc(), global.getAddrSpace(),
184+
getProgramAddressSpace(rewriter),
185+
global.getSymName(),
186+
convertType(addr.getType()), addr);
187+
return mlir::success();
188+
}
189+
179190
auto global = addr->getParentOfType<mlir::ModuleOp>()
180191
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
181192
replaceWithAddrOfOrASCast(
@@ -3231,7 +3242,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
32313242

32323243
if (global.getDataAttr() &&
32333244
*global.getDataAttr() == cuf::DataAttribute::Constant)
3234-
TODO(global.getLoc(), "CUDA Fortran CONSTANT variable code generation");
3245+
g.setAddrSpace(
3246+
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant));
32353247

32363248
rewriter.eraseOp(global);
32373249
return mlir::success();

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,31 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
284284
// CHECK-LABEL: llvm.func @_QQxxx()
285285
// CHECK: llvm.alloca %{{.*}} x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<2 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
286286
// CHECK-NOT: llvm.call @_FortranACUFAllocDescriptor
287+
288+
// -----
289+
290+
module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
291+
gpu.module @cuda_device_mod {
292+
fir.global @_QMkernelsEinitial_val {data_attr = #cuf.cuda<constant>} : i32 {
293+
%0 = fir.zero_bits i32
294+
fir.has_value %0 : i32
295+
}
296+
gpu.func @_QMkernelsPassign(%arg0: !fir.ref<!fir.array<?xi32>>) kernel {
297+
%c-1 = arith.constant -1 : index
298+
%c1_i32 = arith.constant 1 : i32
299+
%0 = arith.constant 1 : i32
300+
%1 = arith.addi %0, %c1_i32 : i32
301+
%2 = fir.address_of(@_QMkernelsEinitial_val) : !fir.ref<i32>
302+
%4 = fir.load %2 : !fir.ref<i32>
303+
%5 = fir.convert %1 : (i32) -> i64
304+
%6 = fircg.ext_array_coor %arg0(%c-1)<%5> : (!fir.ref<!fir.array<?xi32>>, index, i64) -> !fir.ref<i32>
305+
fir.store %4 to %6 : !fir.ref<i32>
306+
gpu.return
307+
}
308+
}
309+
}
310+
311+
// CHECK: llvm.mlir.global external @_QMkernelsEinitial_val() {addr_space = 4 : i32} : i32
312+
// CHECK-LABEL: gpu.func @_QMkernelsPassign
313+
// CHECK: %[[ADDROF:.*]] = llvm.mlir.addressof @_QMkernelsEinitial_val : !llvm.ptr<4>
314+
// CHECK: %{{.*}} = llvm.addrspacecast %[[ADDROF]] : !llvm.ptr<4> to !llvm.ptr

0 commit comments

Comments
 (0)