-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang-rt] Add APIs to retrive base_addr and DataSizeInBytes from Descriptor. #152756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@llvm/pr-subscribers-flang-fir-hlfir Author: Chaitanya (skc7) ChangesThis PR adds below APIs to flang-rt: Full diff: https://github.com/llvm/llvm-project/pull/152756.diff 5 Files Affected:
diff --git a/flang-rt/lib/runtime/support.cpp b/flang-rt/lib/runtime/support.cpp
index 9beb46e48a11e..ffeaafaa162ea 100644
--- a/flang-rt/lib/runtime/support.cpp
+++ b/flang-rt/lib/runtime/support.cpp
@@ -48,6 +48,30 @@ void RTDEF(CopyAndUpdateDescriptor)(Descriptor &to, const Descriptor &from,
}
}
+void *RTDEF(DescriptorGetBaseAddress)(
+ const Descriptor &desc, const char *sourceFile, int sourceLine) {
+ Terminator terminator{sourceFile, sourceLine};
+ void *baseAddr = desc.raw().base_addr;
+ if (!baseAddr) {
+ terminator.Crash("Could not retrieve Descriptor's base address");
+ }
+ return baseAddr;
+}
+
+std::size_t RTDEF(DescriptorGetDataSizeInBytes)(
+ const Descriptor &desc, const char *sourceFile, int sourceLine) {
+ Terminator terminator{sourceFile, sourceLine};
+ std::size_t descElements{desc.Elements()};
+ if (!descElements) {
+ terminator.Crash("Could not retrieve Descriptor's Elements");
+ }
+ std::size_t descElementBytes{desc.ElementBytes()};
+ if (!descElementBytes) {
+ terminator.Crash("Could not retrieve Descriptor's ElementBytes");
+ }
+ return descElements * descElementBytes;
+}
+
RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
diff --git a/flang-rt/unittests/Runtime/Support.cpp b/flang-rt/unittests/Runtime/Support.cpp
index 46c6805d5d238..264dde872c242 100644
--- a/flang-rt/unittests/Runtime/Support.cpp
+++ b/flang-rt/unittests/Runtime/Support.cpp
@@ -98,3 +98,26 @@ TEST(IsContiguous, Basic) {
EXPECT_TRUE(RTNAME(IsContiguousUpTo)(section, 1));
EXPECT_FALSE(RTNAME(IsContiguousUpTo)(section, 2));
}
+
+TEST(DescriptorGetBaseAddress, Basic) {
+ auto array{MakeArray<TypeCategory::Integer, 4>(
+ std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
+ void *baseAddr = RTNAME(DescriptorGetBaseAddress)(*array);
+ EXPECT_NE(baseAddr, nullptr);
+ EXPECT_EQ(baseAddr, array->raw().base_addr);
+}
+
+TEST(DescriptorGetDataSizeInBytes, Basic) {
+ // Test with a 2x3 integer*4 array
+ auto int4Array{MakeArray<TypeCategory::Integer, 4>({2, 3})};
+ EXPECT_EQ(RTNAME(DescriptorGetDataSizeInBytes)(*int4Array),
+ 6 * sizeof(std::int32_t));
+ // Test with a 1D, 5-element real*8 array
+ auto real8Array{MakeArray<TypeCategory::Real, 8>({5})};
+ EXPECT_EQ(
+ RTNAME(DescriptorGetDataSizeInBytes)(*real8Array), 5 * sizeof(double));
+ // Test with a scalar logical*1
+ auto logical1Scalar{MakeArray<TypeCategory::Logical, 1>({})};
+ EXPECT_EQ(
+ RTNAME(DescriptorGetDataSizeInBytes)(*logical1Scalar), 1 * sizeof(bool));
+}
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Support.h b/flang/include/flang/Optimizer/Builder/Runtime/Support.h
index d0a474d75d2eb..41db61c19b07e 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Support.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Support.h
@@ -31,5 +31,18 @@ void genCopyAndUpdateDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value genIsAssumedSize(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value box);
+/// Generate call to `DescriptorGetBaseAddress` runtime routine.
+mlir::Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value desc,
+ mlir::Value sourceFile,
+ mlir::Value sourceLine);
+
+/// Generate call to `DescriptorGetDataSizeInBytes` runtime routine.
+mlir::Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value desc,
+ mlir::Value sourceFile,
+ mlir::Value sourceLine);
+
} // namespace fir::runtime
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H
diff --git a/flang/include/flang/Runtime/support.h b/flang/include/flang/Runtime/support.h
index 8a345bee7f867..5ebe6c6406a01 100644
--- a/flang/include/flang/Runtime/support.h
+++ b/flang/include/flang/Runtime/support.h
@@ -49,6 +49,14 @@ void RTDECL(CopyAndUpdateDescriptor)(Descriptor &to, const Descriptor &from,
const typeInfo::DerivedType *newDynamicType,
ISO::CFI_attribute_t newAttribute, enum LowerBoundModifier newLowerBounds);
+// Retrieve the base_addr from Descriptor
+void *RTDECL(DescriptorGetBaseAddress)(const Descriptor &desc,
+ const char *sourceFile = nullptr, int sourceLine = 0);
+
+// Retrieve the totalSizeInBytes of data from Descriptor
+std::size_t RTDECL(DescriptorGetDataSizeInBytes)(const Descriptor &desc,
+ const char *sourceFile = nullptr, int sourceLine = 0);
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_SUPPORT_H_
diff --git a/flang/lib/Optimizer/Builder/Runtime/Support.cpp b/flang/lib/Optimizer/Builder/Runtime/Support.cpp
index d0d48ad718da4..12994b596df4b 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Support.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Support.cpp
@@ -54,3 +54,24 @@ mlir::Value fir::runtime::genIsAssumedSize(fir::FirOpBuilder &builder,
auto args = fir::runtime::createArguments(builder, loc, fTy, box);
return fir::CallOp::create(builder, loc, func, args).getResult(0);
}
+
+mlir::Value fir::runtime::genDescriptorGetBaseAddress(
+ fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value desc,
+ mlir::Value sourceFile, mlir::Value sourceLine) {
+ mlir::func::FuncOp baseAddrFunc =
+ fir::runtime::getRuntimeFunc<mkRTKey(DescriptorGetBaseAddress)>(loc,
+ builder);
+ llvm::SmallVector<mlir::Value> args{desc, sourceFile, sourceLine};
+ return fir::CallOp::create(builder, loc, baseAddrFunc, args).getResult(0);
+}
+
+mlir::Value fir::runtime::genDescriptorGetDataSizeInBytes(
+ fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value desc,
+ mlir::Value sourceFile, mlir::Value sourceLine) {
+ mlir::func::FuncOp getDataSizeInBytesFunc =
+ fir::runtime::getRuntimeFunc<mkRTKey(DescriptorGetDataSizeInBytes)>(
+ loc, builder);
+ llvm::SmallVector<mlir::Value> args{desc, sourceFile, sourceLine};
+ return fir::CallOp::create(builder, loc, getDataSizeInBytesFunc, args)
+ .getResult(0);
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the use case for that? We have an operation for the base addr already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please answer Valentin's questions before we proceed with this PR.
Hi @clementval This PR is pre-requisite for #140523 |
There is a There is a pretty similar use case with the CUDA Fortran data transfer on assignment and the runtime entry point is taking descriptor directly. I would suggest to do the same. |
For PR #140523, fortran runtime I have looked at CUDA flang-rt APIs which internally call cuda runtime calls like below. """ |
I don't think that adding OpenMP dependencies in flang runtime is a good idea. By the way the box address can be retrieved without runtime call. |
} | ||
} | ||
|
||
void *RTDEF(DescriptorGetBaseAddress)( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As Valentin said, there is fir.box_addr
operation that allows taking the base address from a descriptor.
return baseAddr; | ||
} | ||
|
||
std::size_t RTDEF(DescriptorGetDataSizeInBytes)( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is fir.box_total_elements
(probably not implemented end-to-end right now) and fir.box_elesize
, which can be used to compute the data size in bytes (of course, assuming that the data is contiguous).
So there are existing operations that should allow you to get all the data for omp_target_memcpy
invocation and insert it in the compiler generated code rather than doing it in the Fortran runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @clementval and @vzakhari for feedback. Will check if these fir ops work in our scenario e2e. Closing the pull request for now.
This PR adds below APIs to flang-rt:
DescriptorGetBaseAddress to retrive base_addr from Descriptor
DescriptorGetDataSizeInBytes to retrive the total Size in bytes of data from Descriptor.