Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions flang-rt/lib/runtime/support.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,30 @@ void RTDEF(CopyAndUpdateDescriptor)(Descriptor &to, const Descriptor &from,
}
}

void *RTDEF(DescriptorGetBaseAddress)(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Valentin said, there is fir.box_addr operation that allows taking the base address from a descriptor.

const Descriptor &desc, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
void *baseAddr = desc.raw().base_addr;
if (!baseAddr) {
terminator.Crash("Could not retrieve Descriptor's base address");
}
return baseAddr;
}

std::size_t RTDEF(DescriptorGetDataSizeInBytes)(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is fir.box_total_elements (probably not implemented end-to-end right now) and fir.box_elesize, which can be used to compute the data size in bytes (of course, assuming that the data is contiguous).

So there are existing operations that should allow you to get all the data for omp_target_memcpy invocation and insert it in the compiler generated code rather than doing it in the Fortran runtime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @clementval and @vzakhari for feedback. Will check if these fir ops work in our scenario e2e. Closing the pull request for now.

const Descriptor &desc, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
std::size_t descElements{desc.Elements()};
if (!descElements) {
terminator.Crash("Could not retrieve Descriptor's Elements");
}
std::size_t descElementBytes{desc.ElementBytes()};
if (!descElementBytes) {
terminator.Crash("Could not retrieve Descriptor's ElementBytes");
}
return descElements * descElementBytes;
}

RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
23 changes: 23 additions & 0 deletions flang-rt/unittests/Runtime/Support.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,26 @@ TEST(IsContiguous, Basic) {
EXPECT_TRUE(RTNAME(IsContiguousUpTo)(section, 1));
EXPECT_FALSE(RTNAME(IsContiguousUpTo)(section, 2));
}

TEST(DescriptorGetBaseAddress, Basic) {
auto array{MakeArray<TypeCategory::Integer, 4>(
std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
void *baseAddr = RTNAME(DescriptorGetBaseAddress)(*array);
EXPECT_NE(baseAddr, nullptr);
EXPECT_EQ(baseAddr, array->raw().base_addr);
}

TEST(DescriptorGetDataSizeInBytes, Basic) {
// Test with a 2x3 integer*4 array
auto int4Array{MakeArray<TypeCategory::Integer, 4>({2, 3})};
EXPECT_EQ(RTNAME(DescriptorGetDataSizeInBytes)(*int4Array),
6 * sizeof(std::int32_t));
// Test with a 1D, 5-element real*8 array
auto real8Array{MakeArray<TypeCategory::Real, 8>({5})};
EXPECT_EQ(
RTNAME(DescriptorGetDataSizeInBytes)(*real8Array), 5 * sizeof(double));
// Test with a scalar logical*1
auto logical1Scalar{MakeArray<TypeCategory::Logical, 1>({})};
EXPECT_EQ(
RTNAME(DescriptorGetDataSizeInBytes)(*logical1Scalar), 1 * sizeof(bool));
}
13 changes: 13 additions & 0 deletions flang/include/flang/Optimizer/Builder/Runtime/Support.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,18 @@ void genCopyAndUpdateDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value genIsAssumedSize(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value box);

/// Generate call to `DescriptorGetBaseAddress` runtime routine.
mlir::Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value desc,
mlir::Value sourceFile,
mlir::Value sourceLine);

/// Generate call to `DescriptorGetDataSizeInBytes` runtime routine.
mlir::Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value desc,
mlir::Value sourceFile,
mlir::Value sourceLine);

} // namespace fir::runtime
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H
8 changes: 8 additions & 0 deletions flang/include/flang/Runtime/support.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ void RTDECL(CopyAndUpdateDescriptor)(Descriptor &to, const Descriptor &from,
const typeInfo::DerivedType *newDynamicType,
ISO::CFI_attribute_t newAttribute, enum LowerBoundModifier newLowerBounds);

// Retrieve the base_addr from Descriptor
void *RTDECL(DescriptorGetBaseAddress)(const Descriptor &desc,
const char *sourceFile = nullptr, int sourceLine = 0);

// Retrieve the totalSizeInBytes of data from Descriptor
std::size_t RTDECL(DescriptorGetDataSizeInBytes)(const Descriptor &desc,
const char *sourceFile = nullptr, int sourceLine = 0);

} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_SUPPORT_H_
21 changes: 21 additions & 0 deletions flang/lib/Optimizer/Builder/Runtime/Support.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,24 @@ mlir::Value fir::runtime::genIsAssumedSize(fir::FirOpBuilder &builder,
auto args = fir::runtime::createArguments(builder, loc, fTy, box);
return fir::CallOp::create(builder, loc, func, args).getResult(0);
}

mlir::Value fir::runtime::genDescriptorGetBaseAddress(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value desc,
mlir::Value sourceFile, mlir::Value sourceLine) {
mlir::func::FuncOp baseAddrFunc =
fir::runtime::getRuntimeFunc<mkRTKey(DescriptorGetBaseAddress)>(loc,
builder);
llvm::SmallVector<mlir::Value> args{desc, sourceFile, sourceLine};
return fir::CallOp::create(builder, loc, baseAddrFunc, args).getResult(0);
}

mlir::Value fir::runtime::genDescriptorGetDataSizeInBytes(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value desc,
mlir::Value sourceFile, mlir::Value sourceLine) {
mlir::func::FuncOp getDataSizeInBytesFunc =
fir::runtime::getRuntimeFunc<mkRTKey(DescriptorGetDataSizeInBytes)>(
loc, builder);
llvm::SmallVector<mlir::Value> args{desc, sourceFile, sourceLine};
return fir::CallOp::create(builder, loc, getDataSizeInBytesFunc, args)
.getResult(0);
}