Skip to content

Commit ef8d794

Browse files
committed
Add more error checking
1 parent 6ec68e7 commit ef8d794

File tree

7 files changed

+195
-84
lines changed

7 files changed

+195
-84
lines changed

offload/include/PerThreadTable.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define OFFLOAD_PERTHREADTABLE_H
1515

1616
#include <list>
17+
#include <llvm/Support/Error.h>
1718
#include <memory>
1819
#include <mutex>
1920
#include <type_traits>
@@ -204,6 +205,24 @@ template <typename ContainerType, typename ObjectType> struct PerThreadTable {
204205
}
205206
ThreadDataList.clear();
206207
}
208+
209+
template <class F> llvm::Error deinit(F f) {
210+
std::lock_guard<std::mutex> Lock(Mtx);
211+
for (auto ThData : ThreadDataList) {
212+
if (!ThData->ThEntry || ThData->NElements == 0)
213+
continue;
214+
for (auto &Obj : *ThData->ThEntry) {
215+
if constexpr (is_associative<ContainerType>::value) {
216+
if (auto Err = f(Obj.second))
217+
return Err;
218+
} else {
219+
if (auto Err = f(Obj))
220+
return Err;
221+
}
222+
}
223+
}
224+
return llvm::Error::success();
225+
}
207226
};
208227

209228
template <typename T, typename = std::void_t<>> struct ContainerValueType {

offload/plugins-nextgen/level_zero/include/L0Context.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ class L0ContextTLSTy {
2727
auto &getStagingBuffer() { return StagingBuffer; }
2828
const auto &getStagingBuffer() const { return StagingBuffer; }
2929

30-
void clear() { StagingBuffer.clear(); }
30+
Error deinit() { return StagingBuffer.clear(); }
3131
};
3232

3333
struct L0ContextTLSTableTy
3434
: public PerThreadContainer<
3535
std::unordered_map<ze_context_handle_t, L0ContextTLSTy>> {
36-
void clear() {
37-
PerThreadTable::clear([](L0ContextTLSTy &Entry) { Entry.clear(); });
36+
Error deinit() {
37+
return PerThreadTable::deinit(
38+
[](L0ContextTLSTy &Entry) -> auto { return Entry.deinit(); });
3839
}
3940
};
4041

offload/plugins-nextgen/level_zero/include/L0Device.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -423,55 +423,55 @@ class L0DeviceTy final : public GenericDeviceTy {
423423

424424
// Command queues related functions
425425
/// Create a command list with given ordinal and flags
426-
ze_command_list_handle_t createCmdList(ze_context_handle_t Context,
426+
Expected<ze_command_list_handle_t> createCmdList(ze_context_handle_t Context,
427427
ze_device_handle_t Device,
428428
uint32_t Ordinal,
429429
ze_command_list_flags_t Flags,
430430
const std::string_view DeviceIdStr);
431431

432432
/// Create a command list with default flags
433-
ze_command_list_handle_t createCmdList(ze_context_handle_t Context,
433+
Expected<ze_command_list_handle_t> createCmdList(ze_context_handle_t Context,
434434
ze_device_handle_t Device,
435435
uint32_t Ordinal,
436436
const std::string_view DeviceIdStr);
437437

438-
ze_command_list_handle_t getCmdList();
438+
Expected<ze_command_list_handle_t> getCmdList();
439439

440440
/// Create a command queue with given ordinal and flags
441-
ze_command_queue_handle_t createCmdQueue(ze_context_handle_t Context,
441+
Expected<ze_command_queue_handle_t> createCmdQueue(ze_context_handle_t Context,
442442
ze_device_handle_t Device,
443443
uint32_t Ordinal, uint32_t Index,
444444
ze_command_queue_flags_t Flags,
445445
const std::string_view DeviceIdStr);
446446

447447
/// Create a command queue with default flags
448-
ze_command_queue_handle_t createCmdQueue(ze_context_handle_t Context,
448+
Expected<ze_command_queue_handle_t> createCmdQueue(ze_context_handle_t Context,
449449
ze_device_handle_t Device,
450450
uint32_t Ordinal, uint32_t Index,
451451
const std::string_view DeviceIdStr,
452452
bool InOrder = false);
453453

454454
/// Create a new command queue for the given OpenMP device ID
455-
ze_command_queue_handle_t createCommandQueue(bool InOrder = false);
455+
Expected<ze_command_queue_handle_t> createCommandQueue(bool InOrder = false);
456456

457457
/// Create an immediate command list
458-
ze_command_list_handle_t createImmCmdList(uint32_t Ordinal, uint32_t Index,
458+
Expected<ze_command_list_handle_t> createImmCmdList(uint32_t Ordinal, uint32_t Index,
459459
bool InOrder = false);
460460

461461
/// Create an immediate command list for computing
462-
ze_command_list_handle_t createImmCmdList(bool InOrder = false) {
462+
Expected<ze_command_list_handle_t> createImmCmdList(bool InOrder = false) {
463463
return createImmCmdList(getComputeEngine(), getComputeIndex(), InOrder);
464464
}
465465

466466
/// Create an immediate command list for copying
467-
ze_command_list_handle_t createImmCopyCmdList();
468-
ze_command_queue_handle_t getCmdQueue();
469-
ze_command_list_handle_t getCopyCmdList();
470-
ze_command_queue_handle_t getCopyCmdQueue();
471-
ze_command_list_handle_t getLinkCopyCmdList();
472-
ze_command_queue_handle_t getLinkCopyCmdQueue();
473-
ze_command_list_handle_t getImmCmdList();
474-
ze_command_list_handle_t getImmCopyCmdList();
467+
Expected<ze_command_list_handle_t> createImmCopyCmdList();
468+
Expected<ze_command_queue_handle_t> getCmdQueue();
469+
Expected<ze_command_list_handle_t> getCopyCmdList();
470+
Expected<ze_command_queue_handle_t> getCopyCmdQueue();
471+
Expected<ze_command_list_handle_t> getLinkCopyCmdList();
472+
Expected<ze_command_queue_handle_t> getLinkCopyCmdQueue();
473+
Expected<ze_command_list_handle_t> getImmCmdList();
474+
Expected<ze_command_list_handle_t> getImmCopyCmdList();
475475

476476
/// Enqueue copy command
477477
Error enqueueMemCopy(void *Dst, const void *Src, size_t Size,

offload/plugins-nextgen/level_zero/include/L0Memory.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -502,13 +502,13 @@ class StagingBufferTy {
502502
/// Next buffer location in the buffers
503503
size_t Offset = 0;
504504

505-
void *addBuffers() {
505+
Expected<void *> addBuffers() {
506506
ze_host_mem_alloc_desc_t AllocDesc{ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
507507
nullptr, 0};
508508
void *Ret = nullptr;
509509
size_t AllocSize = Size * Count;
510-
CALL_ZE_RET_NULL(zeMemAllocHost, Context, &AllocDesc, AllocSize,
511-
L0DefaultAlignment, &Ret);
510+
CALL_ZE_RET_ERROR(zeMemAllocHost, Context, &AllocDesc, AllocSize,
511+
L0DefaultAlignment, &Ret);
512512
Buffers.push_back(Ret);
513513
return Ret;
514514
}
@@ -522,12 +522,13 @@ class StagingBufferTy {
522522

523523
~StagingBufferTy() {}
524524

525-
void clear() {
525+
Error clear() {
526526
ze_result_t Rc;
527527
(void)Rc; // GCC build compiler thinks Rc is unused for some reason.
528528
for (auto Ptr : Buffers)
529-
CALL_ZE(Rc, zeMemFree, Context, Ptr);
529+
CALL_ZE_RET_ERROR(zeMemFree, Context, Ptr);
530530
Context = nullptr;
531+
return Plugin::success();
531532
}
532533

533534
bool initialized() const { return Context != nullptr; }
@@ -541,23 +542,26 @@ class StagingBufferTy {
541542
void reset() { Offset = 0; }
542543

543544
/// Always return the first buffer
544-
void *get() {
545+
Expected<void *> get() {
545546
if (Size == 0 || Count == 0)
546547
return nullptr;
547548
return Buffers.empty() ? addBuffers() : Buffers.front();
548549
}
549550

550551
/// Return the next available buffer
551-
void *getNext() {
552+
Expected<void *> getNext() {
552553
void *Ret = nullptr;
553554
if (Size == 0 || Count == 0)
554555
return Ret;
555556

556557
size_t AllocSize = Size * Count;
557558
bool NeedToGrow = Buffers.empty() || Offset >= Buffers.size() * AllocSize;
558-
if (NeedToGrow)
559-
Ret = addBuffers();
560-
else
559+
if (NeedToGrow) {
560+
auto PtrOrErr = addBuffers();
561+
if (!PtrOrErr)
562+
return PtrOrErr.takeError();
563+
Ret = *PtrOrErr;
564+
} else
561565
Ret = reinterpret_cast<void *>(
562566
reinterpret_cast<uintptr_t>(Buffers.back()) + (Offset % AllocSize));
563567

@@ -569,7 +573,7 @@ class StagingBufferTy {
569573
}
570574

571575
/// Return either a fixed buffer or next buffer
572-
void *get(bool Next) { return Next ? getNext() : get(); }
576+
Expected<void *> get(bool Next) { return Next ? getNext() : get(); }
573577
};
574578

575579
} // namespace llvm::omp::target::plugin

0 commit comments

Comments
 (0)