Skip to content

Commit 829e899

Browse files
authored
[flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV, __LDCG with arrays (#130357)
1 parent 8ac359b commit 829e899

File tree

3 files changed

+179
-20
lines changed

3 files changed

+179
-20
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ struct IntrinsicLibrary {
224224
fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
225225
void genCpuTime(llvm::ArrayRef<fir::ExtendedValue>);
226226
fir::ExtendedValue genCshift(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
227+
template <const char *fctName, int extent>
228+
fir::ExtendedValue genCUDALDXXFunc(mlir::Type,
229+
llvm::ArrayRef<fir::ExtendedValue>);
227230
fir::ExtendedValue genCAssociatedCFunPtr(mlir::Type,
228231
llvm::ArrayRef<fir::ExtendedValue>);
229232
fir::ExtendedValue genCAssociatedCPtr(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,34 @@ using I = IntrinsicLibrary;
106106
/// argument is an optional variable in the current scope).
107107
static constexpr bool handleDynamicOptional = true;
108108

109+
/// TODO: Move all CUDA Fortran intrinsic handlers into its own file similar to
110+
/// PPC.
111+
static const char __ldca_i4x4[] = "__ldca_i4x4_";
112+
static const char __ldca_i8x2[] = "__ldca_i8x2_";
113+
static const char __ldca_r2x2[] = "__ldca_r2x2_";
114+
static const char __ldca_r4x4[] = "__ldca_r4x4_";
115+
static const char __ldca_r8x2[] = "__ldca_r8x2_";
116+
static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
117+
static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
118+
static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
119+
static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
120+
static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
121+
static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
122+
static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
123+
static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
124+
static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
125+
static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
126+
static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
127+
static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
128+
static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
129+
static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
130+
static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
131+
static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
132+
static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
133+
static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
134+
static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
135+
static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
136+
109137
/// Table that drives the fir generation depending on the intrinsic or intrinsic
110138
/// module procedure one to one mapping with Fortran arguments. If no mapping is
111139
/// defined here for a generic intrinsic, genRuntimeCall will be called
@@ -114,6 +142,106 @@ static constexpr bool handleDynamicOptional = true;
114142
/// argument must not be lowered by value. In which case, the lowering rules
115143
/// should be provided for all the intrinsic arguments for completeness.
116144
static constexpr IntrinsicHandler handlers[]{
145+
{"__ldca_i4x4",
146+
&I::genCUDALDXXFunc<__ldca_i4x4, 4>,
147+
{{{"a", asAddr}}},
148+
/*isElemental=*/false},
149+
{"__ldca_i8x2",
150+
&I::genCUDALDXXFunc<__ldca_i8x2, 2>,
151+
{{{"a", asAddr}}},
152+
/*isElemental=*/false},
153+
{"__ldca_r2x2",
154+
&I::genCUDALDXXFunc<__ldca_r2x2, 2>,
155+
{{{"a", asAddr}}},
156+
/*isElemental=*/false},
157+
{"__ldca_r4x4",
158+
&I::genCUDALDXXFunc<__ldca_r4x4, 4>,
159+
{{{"a", asAddr}}},
160+
/*isElemental=*/false},
161+
{"__ldca_r8x2",
162+
&I::genCUDALDXXFunc<__ldca_r8x2, 2>,
163+
{{{"a", asAddr}}},
164+
/*isElemental=*/false},
165+
{"__ldcg_i4x4",
166+
&I::genCUDALDXXFunc<__ldcg_i4x4, 4>,
167+
{{{"a", asAddr}}},
168+
/*isElemental=*/false},
169+
{"__ldcg_i8x2",
170+
&I::genCUDALDXXFunc<__ldcg_i8x2, 2>,
171+
{{{"a", asAddr}}},
172+
/*isElemental=*/false},
173+
{"__ldcg_r2x2",
174+
&I::genCUDALDXXFunc<__ldcg_r2x2, 2>,
175+
{{{"a", asAddr}}},
176+
/*isElemental=*/false},
177+
{"__ldcg_r4x4",
178+
&I::genCUDALDXXFunc<__ldcg_r4x4, 4>,
179+
{{{"a", asAddr}}},
180+
/*isElemental=*/false},
181+
{"__ldcg_r8x2",
182+
&I::genCUDALDXXFunc<__ldcg_r8x2, 2>,
183+
{{{"a", asAddr}}},
184+
/*isElemental=*/false},
185+
{"__ldcs_i4x4",
186+
&I::genCUDALDXXFunc<__ldcs_i4x4, 4>,
187+
{{{"a", asAddr}}},
188+
/*isElemental=*/false},
189+
{"__ldcs_i8x2",
190+
&I::genCUDALDXXFunc<__ldcs_i8x2, 2>,
191+
{{{"a", asAddr}}},
192+
/*isElemental=*/false},
193+
{"__ldcs_r2x2",
194+
&I::genCUDALDXXFunc<__ldcs_r2x2, 2>,
195+
{{{"a", asAddr}}},
196+
/*isElemental=*/false},
197+
{"__ldcs_r4x4",
198+
&I::genCUDALDXXFunc<__ldcs_r4x4, 4>,
199+
{{{"a", asAddr}}},
200+
/*isElemental=*/false},
201+
{"__ldcs_r8x2",
202+
&I::genCUDALDXXFunc<__ldcs_r8x2, 2>,
203+
{{{"a", asAddr}}},
204+
/*isElemental=*/false},
205+
{"__ldcv_i4x4",
206+
&I::genCUDALDXXFunc<__ldcv_i4x4, 4>,
207+
{{{"a", asAddr}}},
208+
/*isElemental=*/false},
209+
{"__ldcv_i8x2",
210+
&I::genCUDALDXXFunc<__ldcv_i8x2, 2>,
211+
{{{"a", asAddr}}},
212+
/*isElemental=*/false},
213+
{"__ldcv_r2x2",
214+
&I::genCUDALDXXFunc<__ldcv_r2x2, 2>,
215+
{{{"a", asAddr}}},
216+
/*isElemental=*/false},
217+
{"__ldcv_r4x4",
218+
&I::genCUDALDXXFunc<__ldcv_r4x4, 4>,
219+
{{{"a", asAddr}}},
220+
/*isElemental=*/false},
221+
{"__ldcv_r8x2",
222+
&I::genCUDALDXXFunc<__ldcv_r8x2, 2>,
223+
{{{"a", asAddr}}},
224+
/*isElemental=*/false},
225+
{"__ldlu_i4x4",
226+
&I::genCUDALDXXFunc<__ldlu_i4x4, 4>,
227+
{{{"a", asAddr}}},
228+
/*isElemental=*/false},
229+
{"__ldlu_i8x2",
230+
&I::genCUDALDXXFunc<__ldlu_i8x2, 2>,
231+
{{{"a", asAddr}}},
232+
/*isElemental=*/false},
233+
{"__ldlu_r2x2",
234+
&I::genCUDALDXXFunc<__ldlu_r2x2, 2>,
235+
{{{"a", asAddr}}},
236+
/*isElemental=*/false},
237+
{"__ldlu_r4x4",
238+
&I::genCUDALDXXFunc<__ldlu_r4x4, 4>,
239+
{{{"a", asAddr}}},
240+
/*isElemental=*/false},
241+
{"__ldlu_r8x2",
242+
&I::genCUDALDXXFunc<__ldlu_r8x2, 2>,
243+
{{{"a", asAddr}}},
244+
/*isElemental=*/false},
117245
{"abort", &I::genAbort},
118246
{"abs", &I::genAbs},
119247
{"achar", &I::genChar},
@@ -3544,6 +3672,29 @@ IntrinsicLibrary::genCshift(mlir::Type resultType,
35443672
return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT");
35453673
}
35463674

3675+
// __LDCA, __LDCS, __LDLU, __LDCV
3676+
template <const char *fctName, int extent>
3677+
fir::ExtendedValue
3678+
IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType,
3679+
llvm::ArrayRef<fir::ExtendedValue> args) {
3680+
assert(args.size() == 1);
3681+
mlir::Type resTy = fir::SequenceType::get(extent, resultType);
3682+
mlir::Value arg = fir::getBase(args[0]);
3683+
mlir::Value res = builder.create<fir::AllocaOp>(loc, resTy);
3684+
if (mlir::isa<fir::BaseBoxType>(arg.getType()))
3685+
arg = builder.create<fir::BoxAddrOp>(loc, arg);
3686+
mlir::FunctionType ftype =
3687+
mlir::FunctionType::get(arg.getContext(), {resTy, resTy}, {});
3688+
auto funcOp = builder.createFunction(loc, fctName, ftype);
3689+
llvm::SmallVector<mlir::Value> funcArgs;
3690+
funcArgs.push_back(res);
3691+
funcArgs.push_back(arg);
3692+
builder.create<fir::CallOp>(loc, funcOp, funcArgs);
3693+
mlir::Value ext =
3694+
builder.createIntegerConstant(loc, builder.getIndexType(), extent);
3695+
return fir::ArrayBoxValue(res, {ext});
3696+
}
3697+
35473698
// DATE_AND_TIME
35483699
void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
35493700
assert(args.size() == 4 && "date_and_time has 4 args");

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,11 @@ attributes(global) subroutine __ldXXi4(b)
217217
end
218218

219219
! CHECK-LABEL: func.func @_QP__ldxxi4
220-
! CHECK: __ldca_i4x4
221-
! CHECK: __ldcg_i4x4
222-
! CHECK: __ldcs_i4x4
223-
! CHECK: __ldlu_i4x4
220+
! CHECK: fir.call @__ldca_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
221+
! CHECK: fir.call @__ldcg_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
222+
! CHECK: fir.call @__ldcs_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
223+
! CHECK: fir.call @__ldlu_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
224+
! CHECK: fir.call @__ldcv_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
224225

225226
attributes(global) subroutine __ldXXi8(b)
226227
integer(8), device :: b(*)
@@ -233,10 +234,11 @@ attributes(global) subroutine __ldXXi8(b)
233234
end
234235

235236
! CHECK-LABEL: func.func @_QP__ldxxi8
236-
! CHECK: __ldca_i8x2
237-
! CHECK: __ldcg_i8x2
238-
! CHECK: __ldcs_i8x2
239-
! CHECK: __ldlu_i8x2
237+
! CHECK: fir.call @__ldca_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
238+
! CHECK: fir.call @__ldcg_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
239+
! CHECK: fir.call @__ldcs_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
240+
! CHECK: fir.call @__ldlu_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
241+
! CHECK: fir.call @__ldcv_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
240242

241243
attributes(global) subroutine __ldXXr4(b)
242244
real, device :: b(*)
@@ -249,10 +251,11 @@ attributes(global) subroutine __ldXXr4(b)
249251
end
250252

251253
! CHECK-LABEL: func.func @_QP__ldxxr4
252-
! CHECK: __ldca_r4x4
253-
! CHECK: __ldcg_r4x4
254-
! CHECK: __ldcs_r4x4
255-
! CHECK: __ldlu_r4x4
254+
! CHECK: fir.call @__ldca_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
255+
! CHECK: fir.call @__ldcg_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
256+
! CHECK: fir.call @__ldcs_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
257+
! CHECK: fir.call @__ldlu_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
258+
! CHECK: fir.call @__ldcv_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
256259

257260
attributes(global) subroutine __ldXXr2(b)
258261
real(2), device :: b(*)
@@ -265,10 +268,11 @@ attributes(global) subroutine __ldXXr2(b)
265268
end
266269

267270
! CHECK-LABEL: func.func @_QP__ldxxr2
268-
! CHECK: __ldca_r2x2
269-
! CHECK: __ldcg_r2x2
270-
! CHECK: __ldcs_r2x2
271-
! CHECK: __ldlu_r2x2
271+
! CHECK: fir.call @__ldca_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
272+
! CHECK: fir.call @__ldcg_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
273+
! CHECK: fir.call @__ldcs_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
274+
! CHECK: fir.call @__ldlu_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
275+
! CHECK: fir.call @__ldcv_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
272276

273277
attributes(global) subroutine __ldXXr8(b)
274278
real(8), device :: b(*)
@@ -281,7 +285,8 @@ attributes(global) subroutine __ldXXr8(b)
281285
end
282286

283287
! CHECK-LABEL: func.func @_QP__ldxxr8
284-
! CHECK: __ldca_r8x2
285-
! CHECK: __ldcg_r8x2
286-
! CHECK: __ldcs_r8x2
287-
! CHECK: __ldlu_r8x2
288+
! CHECK: fir.call @__ldca_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
289+
! CHECK: fir.call @__ldcg_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
290+
! CHECK: fir.call @__ldcs_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
291+
! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
292+
! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()

0 commit comments

Comments
 (0)