Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ class DeviceContext {
explicit DeviceContext(llvm::StringRef Platform, std::size_t DeviceId = 0);

template <typename T>
ManagedBuffer<T> createManagedBuffer(std::size_t Size) noexcept {
ManagedBuffer<T> createManagedBuffer(std::size_t Size) const noexcept {
void *UntypedAddress = nullptr;

detail::allocManagedMemory(DeviceHandle, Size * sizeof(T), &UntypedAddress);
T *TypedAddress = static_cast<T *>(UntypedAddress);

return ManagedBuffer<T>(getPlatformHandle(), TypedAddress, Size);
return ManagedBuffer<T>(PlatformHandle, TypedAddress, Size);
}

[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
Expand Down Expand Up @@ -120,9 +120,6 @@ class DeviceContext {

[[nodiscard]] llvm::StringRef getPlatform() const noexcept;

[[nodiscard]] llvm::Expected<ol_platform_handle_t>
getPlatformHandle() noexcept;

private:
[[nodiscard]] llvm::Expected<ol_symbol_handle_t>
getKernelHandle(ol_program_handle_t ProgramHandle,
Expand All @@ -134,7 +131,7 @@ class DeviceContext {

std::size_t GlobalDeviceId;
ol_device_handle_t DeviceHandle;
ol_platform_handle_t PlatformHandle = nullptr;
ol_platform_handle_t PlatformHandle;
};
} // namespace mathtest

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
ManagedBuffer &operator=(const ManagedBuffer &) = delete;

ManagedBuffer(ManagedBuffer &&Other) noexcept
: Address(Other.Address), Size(Other.Size) {
: Platform(Other.Platform), Address(Other.Address), Size(Other.Size) {
Other.Platform = nullptr;
Other.Address = nullptr;
Other.Size = 0;
}
Expand All @@ -59,9 +60,11 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
if (Address)
detail::freeDeviceMemory(Platform, Address);

Platform = Other.Platform;
Address = Other.Address;
Size = Other.Size;

Other.Platform = nullptr;
Other.Address = nullptr;
Other.Size = 0;

Expand Down Expand Up @@ -89,7 +92,7 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
std::size_t Size) noexcept
: Platform(Platform), Address(Address), Size(Size) {}

ol_platform_handle_t Platform;
ol_platform_handle_t Platform = nullptr;
T *Address = nullptr;
std::size_t Size = 0;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class [[nodiscard]] GpuMathTest final {

ResultType run(GeneratorType &Generator,
std::size_t BufferSize = DefaultBufferSize,
uint32_t GroupSize = DefaultGroupSize) noexcept {
uint32_t GroupSize = DefaultGroupSize) const noexcept {
assert(BufferSize > 0 && "Buffer size must be a positive value");
assert(GroupSize > 0 && "Group size must be a positive value");

Expand Down Expand Up @@ -128,7 +128,7 @@ class [[nodiscard]] GpuMathTest final {
return *ExpectedKernel;
}

[[nodiscard]] auto createBuffers(std::size_t BufferSize) {
[[nodiscard]] auto createBuffers(std::size_t BufferSize) const {
auto InBuffersTuple = std::apply(
[&](auto... InTypeIdentities) {
return std::make_tuple(
Expand Down
28 changes: 4 additions & 24 deletions offload/unittests/Conformance/lib/DeviceContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ getPlatformBackend(ol_platform_handle_t PlatformHandle) noexcept {

struct Device {
ol_device_handle_t Handle;
ol_platform_handle_t PlatformHandle;
std::string Name;
std::string Platform;
ol_platform_backend_t Backend;
Expand All @@ -124,7 +125,7 @@ const std::vector<Device> &getDevices() {
auto Platform = getPlatformName(PlatformHandle);

static_cast<std::vector<Device> *>(Data)->push_back(
{DeviceHandle, Name, Platform, Backend});
{DeviceHandle, PlatformHandle, Name, Platform, Backend});
}

return true;
Expand Down Expand Up @@ -175,6 +176,7 @@ DeviceContext::DeviceContext(std::size_t GlobalDeviceId)
llvm::Twine(Devices.size()));

DeviceHandle = Devices[GlobalDeviceId].Handle;
PlatformHandle = Devices[GlobalDeviceId].PlatformHandle;
}

DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)
Expand Down Expand Up @@ -210,6 +212,7 @@ DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)

GlobalDeviceId = *FoundGlobalDeviceId;
DeviceHandle = Devices[GlobalDeviceId].Handle;
PlatformHandle = Devices[GlobalDeviceId].PlatformHandle;
}

[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
Expand Down Expand Up @@ -286,29 +289,6 @@ DeviceContext::getKernelHandle(ol_program_handle_t ProgramHandle,
return Handle;
}

llvm::Expected<ol_platform_handle_t>
DeviceContext::getPlatformHandle() noexcept {
if (!PlatformHandle) {
const ol_result_t OlResult =
olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_PLATFORM,
sizeof(PlatformHandle), &PlatformHandle);

if (OlResult != OL_SUCCESS) {
PlatformHandle = nullptr;
llvm::StringRef Details =
OlResult->Details ? OlResult->Details : "No details provided";

// clang-format off
return llvm::createStringError(
llvm::Twine(Details) +
" (code " + llvm::Twine(OlResult->Code) + ")");
// clang-format on
}
}

return PlatformHandle;
}

void DeviceContext::launchKernelImpl(
ol_symbol_handle_t KernelHandle, uint32_t NumGroups, uint32_t GroupSize,
const void *KernelArgs, std::size_t KernelArgsSize) const noexcept {
Expand Down
Loading