Skip to content

Commit a1acadf

Browse files
Add non-fatal tryLoadBinary and tryGetKernel methods
1 parent e935a32 commit a1acadf

File tree

2 files changed

+107
-47
lines changed

2 files changed

+107
-47
lines changed

offload/unittests/Conformance/include/mathtest/DeviceContext.hpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525

2626
#include "llvm/ADT/SetVector.h"
2727
#include "llvm/ADT/StringRef.h"
28+
#include "llvm/Support/Error.h"
2829

2930
#include <cassert>
3031
#include <cstddef>
3132
#include <memory>
33+
#include <optional>
3234
#include <tuple>
3335
#include <type_traits>
3436
#include <utility>
@@ -65,27 +67,47 @@ class DeviceContext {
6567
return ManagedBuffer<T>(TypedAddress, Size);
6668
}
6769

68-
[[nodiscard]] std::shared_ptr<DeviceImage>
69-
loadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName,
70-
llvm::StringRef Extension) const;
71-
7270
[[nodiscard]] std::shared_ptr<DeviceImage>
7371
loadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName) const;
7472

73+
[[nodiscard]] std::optional<std::shared_ptr<DeviceImage>>
74+
tryLoadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName) const;
75+
7576
template <typename KernelSignature>
7677
DeviceKernel<KernelSignature>
7778
getKernel(const std::shared_ptr<DeviceImage> &Image,
7879
llvm::StringRef KernelName) const noexcept {
7980
assert(Image && "Image provided to getKernel is null");
8081

81-
if (Image->DeviceHandle != this->DeviceHandle)
82+
if (Image->DeviceHandle != DeviceHandle)
8283
FATAL_ERROR("Image provided to getKernel was created for a different "
8384
"device");
8485

85-
ol_symbol_handle_t KernelHandle = nullptr;
86-
getKernelImpl(Image->Handle, KernelName, &KernelHandle);
86+
auto KernelHandleOrErr = getKernelImpl(Image->Handle, KernelName);
87+
88+
if (auto Err = KernelHandleOrErr.takeError())
89+
FATAL_ERROR(llvm::toString(std::move(Err)));
90+
91+
return DeviceKernel<KernelSignature>(Image, *KernelHandleOrErr);
92+
}
93+
94+
template <typename KernelSignature>
95+
[[nodiscard]] std::optional<DeviceKernel<KernelSignature>>
96+
tryGetKernel(const std::shared_ptr<DeviceImage> &Image,
97+
llvm::StringRef KernelName) const noexcept {
98+
assert(Image && "Image provided to getKernel is null");
99+
100+
if (Image->DeviceHandle != DeviceHandle)
101+
return std::nullopt;
102+
103+
auto KernelHandleOrErr = getKernelImpl(Image->Handle, KernelName);
87104

88-
return DeviceKernel<KernelSignature>(Image, KernelHandle);
105+
if (auto Err = KernelHandleOrErr.takeError()) {
106+
llvm::consumeError(std::move(Err));
107+
return std::nullopt;
108+
}
109+
110+
return DeviceKernel<KernelSignature>(Image, *KernelHandleOrErr);
89111
}
90112

91113
template <typename KernelSignature, typename... ArgTypes>
@@ -117,14 +139,17 @@ class DeviceContext {
117139
}
118140
}
119141

120-
[[nodiscard]] llvm::StringRef getName() const;
142+
[[nodiscard]] llvm::StringRef getName() const noexcept;
121143

122-
[[nodiscard]] llvm::StringRef getPlatform() const;
144+
[[nodiscard]] llvm::StringRef getPlatform() const noexcept;
123145

124146
private:
125-
void getKernelImpl(ol_program_handle_t ProgramHandle,
126-
llvm::StringRef KernelName,
127-
ol_symbol_handle_t *KernelHandle) const noexcept;
147+
[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
148+
loadBinaryImpl(llvm::StringRef Directory, llvm::StringRef BinaryName) const;
149+
150+
[[nodiscard]] llvm::Expected<ol_symbol_handle_t>
151+
getKernelImpl(ol_program_handle_t ProgramHandle,
152+
llvm::StringRef KernelName) const noexcept;
128153

129154
void launchKernelImpl(ol_symbol_handle_t KernelHandle, const Dim &NumGroups,
130155
const Dim &GroupSize, const void *KernelArgs,

offload/unittests/Conformance/lib/DeviceContext.cpp

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/StringExtras.h"
2222
#include "llvm/ADT/StringRef.h"
2323
#include "llvm/ADT/Twine.h"
24+
#include "llvm/Support/Error.h"
2425
#include "llvm/Support/ErrorHandling.h"
2526
#include "llvm/Support/ErrorOr.h"
2627
#include "llvm/Support/MemoryBuffer.h"
@@ -198,38 +199,55 @@ DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)
198199
}
199200
}
200201

201-
if (!FoundGlobalDeviceId.has_value())
202+
if (!FoundGlobalDeviceId)
202203
FATAL_ERROR("Invalid DeviceId: " + llvm::Twine(DeviceId) +
203204
", but the number of available devices on '" + Platform +
204205
"' is " + llvm::Twine(MatchCount));
205206

206-
GlobalDeviceId = FoundGlobalDeviceId.value();
207+
GlobalDeviceId = *FoundGlobalDeviceId;
207208
DeviceHandle = Devices[GlobalDeviceId].Handle;
208209
}
209210

210-
[[nodiscard]] std::shared_ptr<DeviceImage>
211-
DeviceContext::loadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName,
212-
llvm::StringRef Extension) const {
211+
[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
212+
DeviceContext::loadBinaryImpl(llvm::StringRef Directory,
213+
llvm::StringRef BinaryName) const {
214+
auto Backend = getDevices()[GlobalDeviceId].Backend;
215+
llvm::StringRef Extension;
216+
217+
switch (Backend) {
218+
case OL_PLATFORM_BACKEND_AMDGPU:
219+
Extension = ".amdgpu.bin";
220+
break;
221+
case OL_PLATFORM_BACKEND_CUDA:
222+
Extension = ".nvptx64.bin";
223+
break;
224+
default:
225+
llvm_unreachable("Unsupported backend to infer binary extension");
226+
}
227+
213228
llvm::SmallString<128> FullPath(Directory);
214229
llvm::sys::path::append(FullPath, llvm::Twine(BinaryName) + Extension);
215230

216-
// For simplicity, this implementation intentionally reads the binary from
217-
// disk on every call.
218-
//
219-
// Other use cases could benefit from a global, thread-safe cache to avoid
220-
// redundant file I/O and GPU program creation.
221-
222231
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
223232
llvm::MemoryBuffer::getFile(FullPath);
233+
224234
if (std::error_code ErrorCode = FileOrErr.getError())
225-
FATAL_ERROR(llvm::Twine("Failed to read device binary file '") + FullPath +
226-
"': " + ErrorCode.message());
235+
return llvm::errorCodeToError(ErrorCode);
227236

228237
std::unique_ptr<llvm::MemoryBuffer> &BinaryData = *FileOrErr;
229238

230239
ol_program_handle_t ProgramHandle = nullptr;
231-
OL_CHECK(olCreateProgram(DeviceHandle, BinaryData->getBufferStart(),
232-
BinaryData->getBufferSize(), &ProgramHandle));
240+
const ol_result_t OlResult =
241+
olCreateProgram(DeviceHandle, BinaryData->getBufferStart(),
242+
BinaryData->getBufferSize(), &ProgramHandle);
243+
244+
if (OlResult != OL_SUCCESS) {
245+
llvm::StringRef Details =
246+
OlResult->Details ? OlResult->Details : "No details provided";
247+
248+
return llvm::createStringError(llvm::Twine(Details) + " (Code " +
249+
llvm::Twine(OlResult->Code) + ")");
250+
}
233251

234252
return std::shared_ptr<DeviceImage>(
235253
new DeviceImage(DeviceHandle, ProgramHandle));
@@ -238,29 +256,46 @@ DeviceContext::loadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName,
238256
[[nodiscard]] std::shared_ptr<DeviceImage>
239257
DeviceContext::loadBinary(llvm::StringRef Directory,
240258
llvm::StringRef BinaryName) const {
241-
auto Backend = getDevices()[GlobalDeviceId].Backend;
242-
llvm::StringRef Extension;
259+
auto ImageOrErr = loadBinaryImpl(Directory, BinaryName);
243260

244-
switch (Backend) {
245-
case OL_PLATFORM_BACKEND_AMDGPU:
246-
Extension = ".amdgpu.bin";
247-
break;
248-
case OL_PLATFORM_BACKEND_CUDA:
249-
Extension = ".nvptx64.bin";
250-
break;
251-
default:
252-
llvm_unreachable("Unsupported backend to infer binary extension");
261+
if (auto Err = ImageOrErr.takeError())
262+
FATAL_ERROR(llvm::toString(std::move(Err)));
263+
264+
return std::move(*ImageOrErr);
265+
}
266+
267+
[[nodiscard]] std::optional<std::shared_ptr<DeviceImage>>
268+
DeviceContext::tryLoadBinary(llvm::StringRef Directory,
269+
llvm::StringRef BinaryName) const {
270+
auto ImageOrErr = loadBinaryImpl(Directory, BinaryName);
271+
272+
if (auto Err = ImageOrErr.takeError()) {
273+
llvm::consumeError(std::move(Err));
274+
return std::nullopt;
253275
}
254276

255-
return loadBinary(Directory, BinaryName, Extension);
277+
return std::move(*ImageOrErr);
256278
}
257279

258-
void DeviceContext::getKernelImpl(
259-
ol_program_handle_t ProgramHandle, llvm::StringRef KernelName,
260-
ol_symbol_handle_t *KernelHandle) const noexcept {
280+
[[nodiscard]] llvm::Expected<ol_symbol_handle_t>
281+
DeviceContext::getKernelImpl(ol_program_handle_t ProgramHandle,
282+
llvm::StringRef KernelName) const noexcept {
283+
ol_symbol_handle_t KernelHandle = nullptr;
261284
llvm::SmallString<32> KernelNameBuffer(KernelName);
262-
OL_CHECK(olGetSymbol(ProgramHandle, KernelNameBuffer.c_str(),
263-
OL_SYMBOL_KIND_KERNEL, KernelHandle));
285+
286+
const ol_result_t OlResult =
287+
olGetSymbol(ProgramHandle, KernelNameBuffer.c_str(),
288+
OL_SYMBOL_KIND_KERNEL, &KernelHandle);
289+
290+
if (OlResult != OL_SUCCESS) {
291+
llvm::StringRef Details =
292+
OlResult->Details ? OlResult->Details : "No details provided";
293+
294+
return llvm::createStringError(llvm::Twine(Details) + " (Code " +
295+
llvm::Twine(OlResult->Code) + ")");
296+
}
297+
298+
return KernelHandle;
264299
}
265300

266301
void DeviceContext::launchKernelImpl(
@@ -277,10 +312,10 @@ void DeviceContext::launchKernelImpl(
277312
KernelArgsSize, &LaunchArgs, nullptr));
278313
}
279314

280-
[[nodiscard]] llvm::StringRef DeviceContext::getName() const {
315+
[[nodiscard]] llvm::StringRef DeviceContext::getName() const noexcept {
281316
return getDevices()[GlobalDeviceId].Name;
282317
}
283318

284-
[[nodiscard]] llvm::StringRef DeviceContext::getPlatform() const {
319+
[[nodiscard]] llvm::StringRef DeviceContext::getPlatform() const noexcept {
285320
return getDevices()[GlobalDeviceId].Platform;
286321
}

0 commit comments

Comments
 (0)