Skip to content

Commit ca79e12

Browse files
authored
[flang][cuda] Handle implicit global in cuf kernel and nested statement (#116846)
Update the implicit global detection by looking for them in the CUF kernel and also update to a walk so nested `fir.address_of` in nested statement are also accounted for.
1 parent eff60d8 commit ca79e12

File tree

2 files changed

+115
-12
lines changed

2 files changed

+115
-12
lines changed

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,37 @@ namespace fir {
2626

2727
namespace {
2828

29+
static void processAddrOfOp(fir::AddrOfOp addrOfOp,
30+
mlir::SymbolTable &symbolTable, bool onlyConstant) {
31+
if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
32+
addrOfOp.getSymbol().getRootReference().getValue())) {
33+
bool isCandidate{(onlyConstant ? globalOp.getConstant() : true) &&
34+
!globalOp.getDataAttr()};
35+
if (isCandidate)
36+
globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
37+
addrOfOp.getContext(), globalOp.getConstant()
38+
? cuf::DataAttribute::Constant
39+
: cuf::DataAttribute::Device));
40+
}
41+
}
42+
2943
static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
3044
mlir::SymbolTable &symbolTable,
3145
bool onlyConstant = true) {
3246
auto cudaProcAttr{
3347
funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
34-
if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host)
35-
return;
36-
for (auto addrOfOp : funcOp.getBody().getOps<fir::AddrOfOp>()) {
37-
if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
38-
addrOfOp.getSymbol().getRootReference().getValue())) {
39-
bool isCandidate{(onlyConstant ? globalOp.getConstant() : true) &&
40-
!globalOp.getDataAttr()};
41-
if (isCandidate)
42-
globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
43-
funcOp.getContext(), globalOp.getConstant()
44-
? cuf::DataAttribute::Constant
45-
: cuf::DataAttribute::Device));
48+
if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host) {
49+
// Look for globlas in CUF KERNEL DO operations.
50+
for (auto cufKernelOp : funcOp.getBody().getOps<cuf::KernelOp>()) {
51+
cufKernelOp.walk([&](fir::AddrOfOp addrOfOp) {
52+
processAddrOfOp(addrOfOp, symbolTable, onlyConstant);
53+
});
4654
}
55+
return;
4756
}
57+
funcOp.walk([&](fir::AddrOfOp addrOfOp) {
58+
processAddrOfOp(addrOfOp, symbolTable, onlyConstant);
59+
});
4860
}
4961

5062
class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {

flang/test/Fir/CUDA/cuda-implicit-device-global.f90

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,94 @@ // Test that global used in device function are flagged with the correct
5353

5454
// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
5555
// CHECK-NOT: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a
56+
57+
// -----
58+
59+
func.func @_QPsub1() {
60+
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsub1Ei"}
61+
%1:2 = hlfir.declare %0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
62+
%c1_i32 = arith.constant 1 : i32
63+
%2 = fir.convert %c1_i32 : (i32) -> index
64+
%c100_i32 = arith.constant 100 : i32
65+
%3 = fir.convert %c100_i32 : (i32) -> index
66+
%c1 = arith.constant 1 : index
67+
cuf.kernel<<<*, *>>> (%arg0 : index) = (%2 : index) to (%3 : index) step (%c1 : index) {
68+
%4 = fir.convert %arg0 : (index) -> i32
69+
fir.store %4 to %1#1 : !fir.ref<i32>
70+
%5 = fir.load %1#0 : !fir.ref<i32>
71+
%c1_i32_0 = arith.constant 1 : i32
72+
%6 = arith.cmpi eq, %5, %c1_i32_0 : i32
73+
fir.if %6 {
74+
%c6_i32 = arith.constant 6 : i32
75+
%7 = fir.address_of(@_QQclX91d13f6e74caa2f03965d7a7c6a8fdd5) : !fir.ref<!fir.char<1,10>>
76+
%8 = fir.convert %7 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
77+
%c5_i32 = arith.constant 5 : i32
78+
%9 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %8, %c5_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
79+
%10 = fir.address_of(@_QQclX5465737420504153534544) : !fir.ref<!fir.char<1,11>>
80+
%c11 = arith.constant 11 : index
81+
%11:2 = hlfir.declare %10 typeparams %c11 {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQclX5465737420504153534544"} : (!fir.ref<!fir.char<1,11>>, index) -> (!fir.ref<!fir.char<1,11>>, !fir.ref<!fir.char<1,11>>)
82+
%12 = fir.convert %11#1 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
83+
%13 = fir.convert %c11 : (index) -> i64
84+
%14 = fir.call @_FortranAioOutputAscii(%9, %12, %13) fastmath<contract> : (!fir.ref<i8>, !fir.ref<i8>, i64) -> i1
85+
%15 = fir.call @_FortranAioEndIoStatement(%9) fastmath<contract> : (!fir.ref<i8>) -> i32
86+
}
87+
"fir.end"() : () -> ()
88+
}
89+
return
90+
}
91+
func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
92+
fir.global linkonce @_QQclX91d13f6e74caa2f03965d7a7c6a8fdd5 constant : !fir.char<1,10> {
93+
%0 = fir.string_lit "dummy.cuf\00"(10) : !fir.char<1,10>
94+
fir.has_value %0 : !fir.char<1,10>
95+
}
96+
func.func private @_FortranAioOutputAscii(!fir.ref<i8>, !fir.ref<i8>, i64) -> i1 attributes {fir.io, fir.runtime}
97+
fir.global linkonce @_QQclX5465737420504153534544 constant : !fir.char<1,11> {
98+
%0 = fir.string_lit "Test PASSED"(11) : !fir.char<1,11>
99+
fir.has_value %0 : !fir.char<1,11>
100+
}
101+
102+
// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,11>
103+
104+
// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
105+
// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant
106+
107+
// -----
108+
109+
func.func @_QPsub1() attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
110+
%6 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsub1Ei"}
111+
%7:2 = hlfir.declare %6 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
112+
%12 = fir.load %7#0 : !fir.ref<i32>
113+
%c1_i32 = arith.constant 1 : i32
114+
%13 = arith.cmpi eq, %12, %c1_i32 : i32
115+
fir.if %13 {
116+
%c6_i32 = arith.constant 6 : i32
117+
%14 = fir.address_of(@_QQclX91d13f6e74caa2f03965d7a7c6a8fdd5) : !fir.ref<!fir.char<1,10>>
118+
%15 = fir.convert %14 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
119+
%c3_i32 = arith.constant 3 : i32
120+
%16 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %15, %c3_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
121+
%17 = fir.address_of(@_QQclX5465737420504153534544) : !fir.ref<!fir.char<1,11>>
122+
%c11 = arith.constant 11 : index
123+
%18:2 = hlfir.declare %17 typeparams %c11 {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQclX5465737420504153534544"} : (!fir.ref<!fir.char<1,11>>, index) -> (!fir.ref<!fir.char<1,11>>, !fir.ref<!fir.char<1,11>>)
124+
%19 = fir.convert %18#1 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
125+
%20 = fir.convert %c11 : (index) -> i64
126+
%21 = fir.call @_FortranAioOutputAscii(%16, %19, %20) fastmath<contract> : (!fir.ref<i8>, !fir.ref<i8>, i64) -> i1
127+
%22 = fir.call @_FortranAioEndIoStatement(%16) fastmath<contract> : (!fir.ref<i8>) -> i32
128+
}
129+
return
130+
}
131+
func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
132+
fir.global linkonce @_QQclX91d13f6e74caa2f03965d7a7c6a8fdd5 constant : !fir.char<1,10> {
133+
%0 = fir.string_lit "dummy.cuf\00"(10) : !fir.char<1,10>
134+
fir.has_value %0 : !fir.char<1,10>
135+
}
136+
func.func private @_FortranAioOutputAscii(!fir.ref<i8>, !fir.ref<i8>, i64) -> i1 attributes {fir.io, fir.runtime}
137+
fir.global linkonce @_QQclX5465737420504153534544 constant : !fir.char<1,11> {
138+
%0 = fir.string_lit "Test PASSED"(11) : !fir.char<1,11>
139+
fir.has_value %0 : !fir.char<1,11>
140+
}
141+
func.func private @_FortranAioEndIoStatement(!fir.ref<i8>) -> i32 attributes {fir.io, fir.runtime}
142+
143+
// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,11>
144+
145+
// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
146+
// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant

0 commit comments

Comments
 (0)