diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index 66eafedf15..293f3eea7a 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -28,7 +28,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream, return UR_RESULT_SUCCESS; } try { - auto Result = forLatestEvents( + UR_CHECK_ERROR(forLatestEvents( EventWaitList, NumEventsInWaitList, [Stream, Queue](ur_event_handle_t Event) -> ur_result_t { ScopedDevice Active(Queue->getDevice()); @@ -38,17 +38,13 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream, UR_CHECK_ERROR(hipStreamWaitEvent(Stream, Event->get(), 0)); return UR_RESULT_SUCCESS; } - }); - - if (Result != UR_RESULT_SUCCESS) { - return Result; - } - return UR_RESULT_SUCCESS; + })); } catch (ur_result_t Err) { return Err; } catch (...) { return UR_RESULT_ERROR_UNKNOWN; } + return UR_RESULT_SUCCESS; } // Determine local work sizes that result in uniform work groups. @@ -630,12 +626,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy( try { ScopedDevice Active(hQueue->getDevice()); - ur_result_t Result = UR_RESULT_SUCCESS; auto Stream = hQueue->getNextTransferStream(); if (phEventWaitList) { - Result = enqueueEventsWait(hQueue, Stream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList, + phEventWaitList)); } if (phEvent) { @@ -657,12 +652,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy( *phEvent = RetImplEvent.release(); } - return Result; } catch (ur_result_t Err) { return Err; } catch (...) { return UR_RESULT_ERROR_UNKNOWN; } + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( @@ -672,7 +667,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - ur_result_t Result = UR_RESULT_SUCCESS; void *SrcPtr = std::get(hBufferSrc->Mem).getVoid(hQueue->getDevice()); void *DstPtr = @@ -682,8 +676,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( try { ScopedDevice Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); - Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, + phEventWaitList)); if (phEvent) { RetImplEvent = @@ -692,10 +686,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( UR_CHECK_ERROR(RetImplEvent->start()); } - Result = commonEnqueueMemBufferCopyRect( + UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect( HIPStream, region, &SrcPtr, hipMemoryTypeDevice, srcOrigin, srcRowPitch, srcSlicePitch, &DstPtr, hipMemoryTypeDevice, dstOrigin, dstRowPitch, - dstSlicePitch); + dstSlicePitch)); if (phEvent) { UR_CHECK_ERROR(RetImplEvent->record()); @@ -703,9 +697,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( } } catch (ur_result_t Err) { - Result = Err; + return Err; } - return Result; + return UR_RESULT_SUCCESS; } static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize, @@ -1063,14 +1057,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy( std::get(hImageDst->Mem).getImageType(), UR_RESULT_ERROR_INVALID_MEM_OBJECT); - ur_result_t Result = UR_RESULT_SUCCESS; - try { ScopedDevice Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); if (phEventWaitList) { - Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, + phEventWaitList)); } hipArray *SrcArray = @@ -1110,13 +1102,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy( UR_CHECK_ERROR(RetImplEvent->start()); } - Result = commonEnqueueMemImageNDCopy( + UR_CHECK_ERROR(commonEnqueueMemImageNDCopy( HIPStream, ImgType, AdjustedRegion, SrcArray, hipMemoryTypeArray, - SrcOffset, DstArray, hipMemoryTypeArray, DstOffset); - - if (Result != UR_RESULT_SUCCESS) { - return Result; - } + SrcOffset, DstArray, hipMemoryTypeArray, DstOffset)); if (phEvent) { UR_CHECK_ERROR(RetImplEvent->record()); @@ -1237,7 +1225,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( ur_queue_handle_t hQueue, void *ptr, size_t patternSize, const void *pPattern, size_t size, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - ur_result_t Result = UR_RESULT_SUCCESS; std::unique_ptr EventPtr{nullptr}; try { @@ -1246,8 +1233,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( ur_stream_guard Guard; hipStream_t HIPStream = hQueue->getNextComputeStream( numEventsInWaitList, phEventWaitList, Guard, &StreamToken); - Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, + phEventWaitList)); if (phEvent) { EventPtr = std::unique_ptr(ur_event_handle_t_::makeNative( @@ -1274,8 +1261,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( break; default: - Result = commonMemSetLargePattern(HIPStream, patternSize, size, pPattern, - reinterpret_cast(ptr)); + UR_CHECK_ERROR( + commonMemSetLargePattern(HIPStream, patternSize, size, pPattern, + reinterpret_cast(ptr))); break; } @@ -1284,25 +1272,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( *phEvent = EventPtr.release(); } } catch (ur_result_t Err) { - Result = Err; + return Err; } - return Result; + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc, size_t size, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - ur_result_t Result = UR_RESULT_SUCCESS; - std::unique_ptr EventPtr{nullptr}; try { ScopedDevice Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); - Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, + phEventWaitList)); if (phEvent) { EventPtr = std::unique_ptr(ur_event_handle_t_::makeNative( @@ -1321,9 +1307,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( *phEvent = EventPtr.release(); } } catch (ur_result_t Err) { - Result = Err; + return Err; } - return Result; + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( @@ -1345,13 +1331,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE); #endif - ur_result_t Result = UR_RESULT_SUCCESS; - try { ScopedDevice Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); - Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, + phEventWaitList)); std::unique_ptr EventPtr{nullptr}; @@ -1399,10 +1383,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream)); releaseEvent(); } catch (ur_result_t Err) { - Result = Err; + return Err; } - return Result; + return UR_RESULT_SUCCESS; } /// USM: memadvise API to govern behavior of automatic migration mechanisms @@ -1521,6 +1505,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, UR_RESULT_SUCCESS); return UR_RESULT_ERROR_ADAPTER_SPECIFIC; } + UR_CHECK_ERROR(Result); } releaseEvent(); @@ -1558,13 +1543,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( const void *pSrc, size_t srcPitch, size_t width, size_t height, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - ur_result_t Result = UR_RESULT_SUCCESS; - try { ScopedDevice Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); - Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, - phEventWaitList); + UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, + phEventWaitList)); std::unique_ptr RetImplEvent{nullptr}; if (phEvent) { @@ -1668,10 +1651,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( UR_CHECK_ERROR(hipStreamSynchronize(HIPStream)); } } catch (ur_result_t Err) { - Result = Err; + return Err; } - return Result; + return UR_RESULT_SUCCESS; } namespace {