Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 36 additions & 53 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
return UR_RESULT_SUCCESS;
}
try {
auto Result = forLatestEvents(
UR_CHECK_ERROR(forLatestEvents(
EventWaitList, NumEventsInWaitList,
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
ScopedDevice Active(Queue->getDevice());
Expand All @@ -38,17 +38,13 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
UR_CHECK_ERROR(hipStreamWaitEvent(Stream, Event->get(), 0));
return UR_RESULT_SUCCESS;
}
});

if (Result != UR_RESULT_SUCCESS) {
return Result;
}
return UR_RESULT_SUCCESS;
}));
} catch (ur_result_t Err) {
return Err;
} catch (...) {
return UR_RESULT_ERROR_UNKNOWN;
}
return UR_RESULT_SUCCESS;
}

// Determine local work sizes that result in uniform work groups.
Expand Down Expand Up @@ -630,12 +626,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(

try {
ScopedDevice Active(hQueue->getDevice());
ur_result_t Result = UR_RESULT_SUCCESS;
auto Stream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
phEventWaitList));
}

if (phEvent) {
Expand All @@ -657,12 +652,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
*phEvent = RetImplEvent.release();
}

return Result;
} catch (ur_result_t Err) {
return Err;
} catch (...) {
return UR_RESULT_ERROR_UNKNOWN;
}
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
Expand All @@ -672,7 +667,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
ur_result_t Result = UR_RESULT_SUCCESS;
void *SrcPtr =
std::get<BufferMem>(hBufferSrc->Mem).getVoid(hQueue->getDevice());
void *DstPtr =
Expand All @@ -682,8 +676,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
try {
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));

if (phEvent) {
RetImplEvent =
Expand All @@ -692,20 +686,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
UR_CHECK_ERROR(RetImplEvent->start());
}

Result = commonEnqueueMemBufferCopyRect(
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
HIPStream, region, &SrcPtr, hipMemoryTypeDevice, srcOrigin, srcRowPitch,
srcSlicePitch, &DstPtr, hipMemoryTypeDevice, dstOrigin, dstRowPitch,
dstSlicePitch);
dstSlicePitch));

if (phEvent) {
UR_CHECK_ERROR(RetImplEvent->record());
*phEvent = RetImplEvent.release();
}

} catch (ur_result_t Err) {
Result = Err;
return Err;
}
return Result;
return UR_RESULT_SUCCESS;
}

static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
Expand Down Expand Up @@ -1063,14 +1057,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
std::get<SurfaceMem>(hImageDst->Mem).getImageType(),
UR_RESULT_ERROR_INVALID_MEM_OBJECT);

ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));
}

hipArray *SrcArray =
Expand Down Expand Up @@ -1110,13 +1102,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
UR_CHECK_ERROR(RetImplEvent->start());
}

Result = commonEnqueueMemImageNDCopy(
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
HIPStream, ImgType, AdjustedRegion, SrcArray, hipMemoryTypeArray,
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset);

if (Result != UR_RESULT_SUCCESS) {
return Result;
}
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset));

if (phEvent) {
UR_CHECK_ERROR(RetImplEvent->record());
Expand Down Expand Up @@ -1237,7 +1225,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
ur_result_t Result = UR_RESULT_SUCCESS;
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
Expand All @@ -1246,8 +1233,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
ur_stream_guard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));
if (phEvent) {
EventPtr =
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
Expand All @@ -1274,8 +1261,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
break;

default:
Result = commonMemSetLargePattern(HIPStream, patternSize, size, pPattern,
reinterpret_cast<hipDeviceptr_t>(ptr));
UR_CHECK_ERROR(
commonMemSetLargePattern(HIPStream, patternSize, size, pPattern,
reinterpret_cast<hipDeviceptr_t>(ptr)));
break;
}

Expand All @@ -1284,25 +1272,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
*phEvent = EventPtr.release();
}
} catch (ur_result_t Err) {
Result = Err;
return Err;
}

return Result;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
size_t size, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
ur_result_t Result = UR_RESULT_SUCCESS;

std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));
if (phEvent) {
EventPtr =
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
Expand All @@ -1321,9 +1307,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
*phEvent = EventPtr.release();
}
} catch (ur_result_t Err) {
Result = Err;
return Err;
}
return Result;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
Expand All @@ -1345,13 +1331,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
#endif

ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));

std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

Expand Down Expand Up @@ -1399,10 +1383,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
releaseEvent();
} catch (ur_result_t Err) {
Result = Err;
return Err;
}

return Result;
return UR_RESULT_SUCCESS;
}

/// USM: memadvise API to govern behavior of automatic migration mechanisms
Expand Down Expand Up @@ -1521,6 +1505,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
UR_CHECK_ERROR(Result);
}

releaseEvent();
Expand Down Expand Up @@ -1558,13 +1543,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
const void *pSrc, size_t srcPitch, size_t width, size_t height,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList));

std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
if (phEvent) {
Expand Down Expand Up @@ -1668,10 +1651,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
UR_CHECK_ERROR(hipStreamSynchronize(HIPStream));
}
} catch (ur_result_t Err) {
Result = Err;
return Err;
}

return Result;
return UR_RESULT_SUCCESS;
}

namespace {
Expand Down
Loading