-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV, __LDCG with arrays #130357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Member
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: __LDCA, __LDCS, __LDLU, __LDCV, and __LDCG, in some of their forms, take an array argument and return an array. These functions are implemented with the return array passed as the first argument. Add custom lowering to match the implemented C function. Full diff: https://github.com/llvm/llvm-project/pull/130357.diff 3 Files Affected:
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index c82e5265970c5..3301b7195d7de 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -223,6 +223,9 @@ struct IntrinsicLibrary {
fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
void genCpuTime(llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCshift(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ template <const char *fctName, int extent>
+ fir::ExtendedValue genCUDALDXXFunc(mlir::Type,
+ llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCAssociatedCFunPtr(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCAssociatedCPtr(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ede3be074a820..bc3c6fcdd853d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -106,6 +106,34 @@ using I = IntrinsicLibrary;
/// argument is an optional variable in the current scope).
static constexpr bool handleDynamicOptional = true;
+/// TODO: Move all CUDA Fortran intrinsic handlers into their own file similar to
+/// PPC.
+static const char __ldca_i4x4[] = "__ldca_i4x4_";
+static const char __ldca_i8x2[] = "__ldca_i8x2_";
+static const char __ldca_r2x2[] = "__ldca_r2x2_";
+static const char __ldca_r4x4[] = "__ldca_r4x4_";
+static const char __ldca_r8x2[] = "__ldca_r8x2_";
+static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
+static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
+static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
+static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
+static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
+static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
+static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
+static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
+static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
+static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
+static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
+static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
+static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
+static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
+static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
+static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
+static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
+static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
+static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
+static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
+
/// Table that drives the fir generation depending on the intrinsic or intrinsic
/// module procedure one to one mapping with Fortran arguments. If no mapping is
/// defined here for a generic intrinsic, genRuntimeCall will be called
@@ -114,6 +142,106 @@ static constexpr bool handleDynamicOptional = true;
/// argument must not be lowered by value. In which case, the lowering rules
/// should be provided for all the intrinsic arguments for completeness.
static constexpr IntrinsicHandler handlers[]{
+ {"__ldca_i4x4",
+ &I::genCUDALDXXFunc<__ldca_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_i8x2",
+ &I::genCUDALDXXFunc<__ldca_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r2x2",
+ &I::genCUDALDXXFunc<__ldca_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r4x4",
+ &I::genCUDALDXXFunc<__ldca_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r8x2",
+ &I::genCUDALDXXFunc<__ldca_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_i4x4",
+ &I::genCUDALDXXFunc<__ldcg_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_i8x2",
+ &I::genCUDALDXXFunc<__ldcg_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r2x2",
+ &I::genCUDALDXXFunc<__ldcg_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r4x4",
+ &I::genCUDALDXXFunc<__ldcg_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r8x2",
+ &I::genCUDALDXXFunc<__ldcg_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_i4x4",
+ &I::genCUDALDXXFunc<__ldcs_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_i8x2",
+ &I::genCUDALDXXFunc<__ldcs_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r2x2",
+ &I::genCUDALDXXFunc<__ldcs_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r4x4",
+ &I::genCUDALDXXFunc<__ldcs_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r8x2",
+ &I::genCUDALDXXFunc<__ldcs_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_i4x4",
+ &I::genCUDALDXXFunc<__ldcv_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_i8x2",
+ &I::genCUDALDXXFunc<__ldcv_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r2x2",
+ &I::genCUDALDXXFunc<__ldcv_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r4x4",
+ &I::genCUDALDXXFunc<__ldcv_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r8x2",
+ &I::genCUDALDXXFunc<__ldcv_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_i4x4",
+ &I::genCUDALDXXFunc<__ldlu_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_i8x2",
+ &I::genCUDALDXXFunc<__ldlu_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r2x2",
+ &I::genCUDALDXXFunc<__ldlu_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r4x4",
+ &I::genCUDALDXXFunc<__ldlu_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r8x2",
+ &I::genCUDALDXXFunc<__ldlu_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
{"abort", &I::genAbort},
{"abs", &I::genAbs},
{"achar", &I::genChar},
@@ -3544,6 +3672,29 @@ IntrinsicLibrary::genCshift(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT");
}
+// __LDCA, __LDCG, __LDCS, __LDLU, __LDCV
+template <const char *fctName, int extent>
+fir::ExtendedValue
+IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1);
+ mlir::Type resTy = fir::SequenceType::get(extent, resultType);
+ mlir::Value arg = fir::getBase(args[0]);
+ mlir::Value res = builder.create<fir::AllocaOp>(loc, resTy);
+ if (mlir::isa<fir::BaseBoxType>(arg.getType()))
+ arg = builder.create<fir::BoxAddrOp>(loc, arg);
+ mlir::FunctionType ftype =
+ mlir::FunctionType::get(arg.getContext(), {resTy, resTy}, {});
+ auto funcOp = builder.createFunction(loc, fctName, ftype);
+ llvm::SmallVector<mlir::Value> funcArgs;
+ funcArgs.push_back(res);
+ funcArgs.push_back(arg);
+ builder.create<fir::CallOp>(loc, funcOp, funcArgs);
+ mlir::Value ext =
+ builder.createIntegerConstant(loc, builder.getIndexType(), extent);
+ return fir::ArrayBoxValue(res, {ext});
+}
+
// DATE_AND_TIME
void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 4 && "date_and_time has 4 args");
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 5f39f78f8ecae..02c94235a354f 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -210,10 +210,11 @@ attributes(global) subroutine __ldXXi4(b)
end
! CHECK-LABEL: func.func @_QP__ldxxi4
-! CHECK: __ldca_i4x4
-! CHECK: __ldcg_i4x4
-! CHECK: __ldcs_i4x4
-! CHECK: __ldlu_i4x4
+! CHECK: fir.call @__ldca_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcg_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcs_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldlu_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcv_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
attributes(global) subroutine __ldXXi8(b)
integer(8), device :: b(*)
@@ -226,10 +227,11 @@ attributes(global) subroutine __ldXXi8(b)
end
! CHECK-LABEL: func.func @_QP__ldxxi8
-! CHECK: __ldca_i8x2
-! CHECK: __ldcg_i8x2
-! CHECK: __ldcs_i8x2
-! CHECK: __ldlu_i8x2
+! CHECK: fir.call @__ldca_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcg_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcs_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldlu_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcv_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
attributes(global) subroutine __ldXXr4(b)
real, device :: b(*)
@@ -242,10 +244,11 @@ attributes(global) subroutine __ldXXr4(b)
end
! CHECK-LABEL: func.func @_QP__ldxxr4
-! CHECK: __ldca_r4x4
-! CHECK: __ldcg_r4x4
-! CHECK: __ldcs_r4x4
-! CHECK: __ldlu_r4x4
+! CHECK: fir.call @__ldca_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcg_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcs_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldlu_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcv_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
attributes(global) subroutine __ldXXr2(b)
real(2), device :: b(*)
@@ -258,10 +261,11 @@ attributes(global) subroutine __ldXXr2(b)
end
! CHECK-LABEL: func.func @_QP__ldxxr2
-! CHECK: __ldca_r2x2
-! CHECK: __ldcg_r2x2
-! CHECK: __ldcs_r2x2
-! CHECK: __ldlu_r2x2
+! CHECK: fir.call @__ldca_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcg_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcs_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldlu_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcv_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
attributes(global) subroutine __ldXXr8(b)
real(8), device :: b(*)
@@ -274,7 +278,8 @@ attributes(global) subroutine __ldXXr8(b)
end
! CHECK-LABEL: func.func @_QP__ldxxr8
-! CHECK: __ldca_r8x2
-! CHECK: __ldcg_r8x2
-! CHECK: __ldcs_r8x2
-! CHECK: __ldlu_r8x2
+! CHECK: fir.call @__ldca_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcg_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcs_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
|
wangzpgi
reviewed
Mar 7, 2025
wangzpgi
approved these changes
Mar 7, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
__LDCA, __LDCS, __LDLU, __LDCV, and __LDCG, in some of their forms, take an array argument and return an array. These functions are implemented with the return array passed as the first argument. Add custom lowering to match the implemented C function.