Skip to content

Commit 0048036

Browse files
committed
[flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV, __LDCG with arrays
1 parent 78631ac commit 0048036

File tree

3 files changed

+179
-20
lines changed

3 files changed

+179
-20
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ struct IntrinsicLibrary {
223223
fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
224224
void genCpuTime(llvm::ArrayRef<fir::ExtendedValue>);
225225
fir::ExtendedValue genCshift(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
226+
/// Lowering for the CUDA Fortran __ld* array intrinsics. \p fctName is the
/// name of the function that gets declared and called, and \p extent is the
/// number of elements of the temporary result array that is returned.
template <const char *fctName, int extent>
227+
fir::ExtendedValue genCUDALDXXFunc(mlir::Type,
228+
llvm::ArrayRef<fir::ExtendedValue>);
226229
fir::ExtendedValue genCAssociatedCFunPtr(mlir::Type,
227230
llvm::ArrayRef<fir::ExtendedValue>);
228231
fir::ExtendedValue genCAssociatedCPtr(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,34 @@ using I = IntrinsicLibrary;
106106
/// argument is an optional variable in the current scope).
107107
static constexpr bool handleDynamicOptional = true;
108108

109+
/// TODO: Move all CUDA Fortran intrinsic handlers into their own file similar to
110+
/// PPC.
111+
/// Callee names used by the __ld* entries of the handlers table: each array is
/// passed as the `fctName` non-type template argument of genCUDALDXXFunc,
/// which declares and calls a function with that name. Every name is the
/// intrinsic name followed by a trailing underscore.
/// NOTE(review): identifiers containing a double underscore are reserved for
/// the implementation in C++ - consider renaming these arrays together with
/// their uses in the handlers table.
static const char __ldca_i4x4[] = "__ldca_i4x4_";
112+
static const char __ldca_i8x2[] = "__ldca_i8x2_";
113+
static const char __ldca_r2x2[] = "__ldca_r2x2_";
114+
static const char __ldca_r4x4[] = "__ldca_r4x4_";
115+
static const char __ldca_r8x2[] = "__ldca_r8x2_";
116+
static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
117+
static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
118+
static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
119+
static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
120+
static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
121+
static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
122+
static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
123+
static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
124+
static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
125+
static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
126+
static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
127+
static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
128+
static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
129+
static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
130+
static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
131+
static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
132+
static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
133+
static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
134+
static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
135+
static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
136+
109137
/// Table that drives the fir generation depending on the intrinsic or intrinsic
110138
/// module procedure one to one mapping with Fortran arguments. If no mapping is
111139
/// defined here for a generic intrinsic, genRuntimeCall will be called
@@ -114,6 +142,106 @@ static constexpr bool handleDynamicOptional = true;
114142
/// argument must not be lowered by value. In which case, the lowering rules
115143
/// should be provided for all the intrinsic arguments for completeness.
116144
static constexpr IntrinsicHandler handlers[]{
145+
{"__ldca_i4x4",
146+
&I::genCUDALDXXFunc<__ldca_i4x4, 4>,
147+
{{{"a", asAddr}}},
148+
/*isElemental=*/false},
149+
{"__ldca_i8x2",
150+
&I::genCUDALDXXFunc<__ldca_i8x2, 2>,
151+
{{{"a", asAddr}}},
152+
/*isElemental=*/false},
153+
{"__ldca_r2x2",
154+
&I::genCUDALDXXFunc<__ldca_r2x2, 2>,
155+
{{{"a", asAddr}}},
156+
/*isElemental=*/false},
157+
{"__ldca_r4x4",
158+
&I::genCUDALDXXFunc<__ldca_r4x4, 4>,
159+
{{{"a", asAddr}}},
160+
/*isElemental=*/false},
161+
{"__ldca_r8x2",
162+
&I::genCUDALDXXFunc<__ldca_r8x2, 2>,
163+
{{{"a", asAddr}}},
164+
/*isElemental=*/false},
165+
{"__ldcg_i4x4",
166+
&I::genCUDALDXXFunc<__ldcg_i4x4, 4>,
167+
{{{"a", asAddr}}},
168+
/*isElemental=*/false},
169+
{"__ldcg_i8x2",
170+
&I::genCUDALDXXFunc<__ldcg_i8x2, 2>,
171+
{{{"a", asAddr}}},
172+
/*isElemental=*/false},
173+
{"__ldcg_r2x2",
174+
&I::genCUDALDXXFunc<__ldcg_r2x2, 2>,
175+
{{{"a", asAddr}}},
176+
/*isElemental=*/false},
177+
{"__ldcg_r4x4",
178+
&I::genCUDALDXXFunc<__ldcg_r4x4, 4>,
179+
{{{"a", asAddr}}},
180+
/*isElemental=*/false},
181+
{"__ldcg_r8x2",
182+
&I::genCUDALDXXFunc<__ldcg_r8x2, 2>,
183+
{{{"a", asAddr}}},
184+
/*isElemental=*/false},
185+
{"__ldcs_i4x4",
186+
&I::genCUDALDXXFunc<__ldcs_i4x4, 4>,
187+
{{{"a", asAddr}}},
188+
/*isElemental=*/false},
189+
{"__ldcs_i8x2",
190+
&I::genCUDALDXXFunc<__ldcs_i8x2, 2>,
191+
{{{"a", asAddr}}},
192+
/*isElemental=*/false},
193+
{"__ldcs_r2x2",
194+
&I::genCUDALDXXFunc<__ldcs_r2x2, 2>,
195+
{{{"a", asAddr}}},
196+
/*isElemental=*/false},
197+
{"__ldcs_r4x4",
198+
&I::genCUDALDXXFunc<__ldcs_r4x4, 4>,
199+
{{{"a", asAddr}}},
200+
/*isElemental=*/false},
201+
{"__ldcs_r8x2",
202+
&I::genCUDALDXXFunc<__ldcs_r8x2, 2>,
203+
{{{"a", asAddr}}},
204+
/*isElemental=*/false},
205+
{"__ldcv_i4x4",
206+
&I::genCUDALDXXFunc<__ldcv_i4x4, 4>,
207+
{{{"a", asAddr}}},
208+
/*isElemental=*/false},
209+
{"__ldcv_i8x2",
210+
&I::genCUDALDXXFunc<__ldcv_i8x2, 2>,
211+
{{{"a", asAddr}}},
212+
/*isElemental=*/false},
213+
{"__ldcv_r2x2",
214+
&I::genCUDALDXXFunc<__ldcv_r2x2, 2>,
215+
{{{"a", asAddr}}},
216+
/*isElemental=*/false},
217+
{"__ldcv_r4x4",
218+
&I::genCUDALDXXFunc<__ldcv_r4x4, 4>,
219+
{{{"a", asAddr}}},
220+
/*isElemental=*/false},
221+
{"__ldcv_r8x2",
222+
&I::genCUDALDXXFunc<__ldcv_r8x2, 2>,
223+
{{{"a", asAddr}}},
224+
/*isElemental=*/false},
225+
{"__ldlu_i4x4",
226+
&I::genCUDALDXXFunc<__ldlu_i4x4, 4>,
227+
{{{"a", asAddr}}},
228+
/*isElemental=*/false},
229+
{"__ldlu_i8x2",
230+
&I::genCUDALDXXFunc<__ldlu_i8x2, 2>,
231+
{{{"a", asAddr}}},
232+
/*isElemental=*/false},
233+
{"__ldlu_r2x2",
234+
&I::genCUDALDXXFunc<__ldlu_r2x2, 2>,
235+
{{{"a", asAddr}}},
236+
/*isElemental=*/false},
237+
{"__ldlu_r4x4",
238+
&I::genCUDALDXXFunc<__ldlu_r4x4, 4>,
239+
{{{"a", asAddr}}},
240+
/*isElemental=*/false},
241+
{"__ldlu_r8x2",
242+
&I::genCUDALDXXFunc<__ldlu_r8x2, 2>,
243+
{{{"a", asAddr}}},
244+
/*isElemental=*/false},
117245
{"abort", &I::genAbort},
118246
{"abs", &I::genAbs},
119247
{"achar", &I::genChar},
@@ -3544,6 +3672,29 @@ IntrinsicLibrary::genCshift(mlir::Type resultType,
35443672
return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT");
35453673
}
35463674

3675+
// __LDCA, __LDCS, __LDLU, __LDCV
3676+
template <const char *fctName, int extent>
3677+
fir::ExtendedValue
3678+
IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType,
3679+
llvm::ArrayRef<fir::ExtendedValue> args) {
3680+
assert(args.size() == 1);
3681+
mlir::Type resTy = fir::SequenceType::get(extent, resultType);
3682+
mlir::Value arg = fir::getBase(args[0]);
3683+
mlir::Value res = builder.create<fir::AllocaOp>(loc, resTy);
3684+
if (mlir::isa<fir::BaseBoxType>(arg.getType()))
3685+
arg = builder.create<fir::BoxAddrOp>(loc, arg);
3686+
mlir::FunctionType ftype =
3687+
mlir::FunctionType::get(arg.getContext(), {resTy, resTy}, {});
3688+
auto funcOp = builder.createFunction(loc, fctName, ftype);
3689+
llvm::SmallVector<mlir::Value> funcArgs;
3690+
funcArgs.push_back(res);
3691+
funcArgs.push_back(arg);
3692+
builder.create<fir::CallOp>(loc, funcOp, funcArgs);
3693+
mlir::Value ext =
3694+
builder.createIntegerConstant(loc, builder.getIndexType(), extent);
3695+
return fir::ArrayBoxValue(res, {ext});
3696+
}
3697+
35473698
// DATE_AND_TIME
35483699
void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
35493700
assert(args.size() == 4 && "date_and_time has 4 args");

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,11 @@ attributes(global) subroutine __ldXXi4(b)
210210
end
211211

212212
! CHECK-LABEL: func.func @_QP__ldxxi4
213-
! CHECK: __ldca_i4x4
214-
! CHECK: __ldcg_i4x4
215-
! CHECK: __ldcs_i4x4
216-
! CHECK: __ldlu_i4x4
213+
! CHECK: fir.call @__ldca_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
214+
! CHECK: fir.call @__ldcg_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
215+
! CHECK: fir.call @__ldcs_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
216+
! CHECK: fir.call @__ldlu_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
217+
! CHECK: fir.call @__ldcv_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
217218

218219
attributes(global) subroutine __ldXXi8(b)
219220
integer(8), device :: b(*)
@@ -226,10 +227,11 @@ attributes(global) subroutine __ldXXi8(b)
226227
end
227228

228229
! CHECK-LABEL: func.func @_QP__ldxxi8
229-
! CHECK: __ldca_i8x2
230-
! CHECK: __ldcg_i8x2
231-
! CHECK: __ldcs_i8x2
232-
! CHECK: __ldlu_i8x2
230+
! CHECK: fir.call @__ldca_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
231+
! CHECK: fir.call @__ldcg_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
232+
! CHECK: fir.call @__ldcs_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
233+
! CHECK: fir.call @__ldlu_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
234+
! CHECK: fir.call @__ldcv_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
233235

234236
attributes(global) subroutine __ldXXr4(b)
235237
real, device :: b(*)
@@ -242,10 +244,11 @@ attributes(global) subroutine __ldXXr4(b)
242244
end
243245

244246
! CHECK-LABEL: func.func @_QP__ldxxr4
245-
! CHECK: __ldca_r4x4
246-
! CHECK: __ldcg_r4x4
247-
! CHECK: __ldcs_r4x4
248-
! CHECK: __ldlu_r4x4
247+
! CHECK: fir.call @__ldca_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
248+
! CHECK: fir.call @__ldcg_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
249+
! CHECK: fir.call @__ldcs_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
250+
! CHECK: fir.call @__ldlu_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
251+
! CHECK: fir.call @__ldcv_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
249252

250253
attributes(global) subroutine __ldXXr2(b)
251254
real(2), device :: b(*)
@@ -258,10 +261,11 @@ attributes(global) subroutine __ldXXr2(b)
258261
end
259262

260263
! CHECK-LABEL: func.func @_QP__ldxxr2
261-
! CHECK: __ldca_r2x2
262-
! CHECK: __ldcg_r2x2
263-
! CHECK: __ldcs_r2x2
264-
! CHECK: __ldlu_r2x2
264+
! CHECK: fir.call @__ldca_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
265+
! CHECK: fir.call @__ldcg_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
266+
! CHECK: fir.call @__ldcs_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
267+
! CHECK: fir.call @__ldlu_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
268+
! CHECK: fir.call @__ldcv_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
265269

266270
attributes(global) subroutine __ldXXr8(b)
267271
real(8), device :: b(*)
@@ -274,7 +278,8 @@ attributes(global) subroutine __ldXXr8(b)
274278
end
275279

276280
! CHECK-LABEL: func.func @_QP__ldxxr8
277-
! CHECK: __ldca_r8x2
278-
! CHECK: __ldcg_r8x2
279-
! CHECK: __ldcs_r8x2
280-
! CHECK: __ldlu_r8x2
281+
! CHECK: fir.call @__ldca_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
282+
! CHECK: fir.call @__ldcg_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
283+
! CHECK: fir.call @__ldcs_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
284+
! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
285+
! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()

0 commit comments

Comments
 (0)