@@ -51,6 +51,48 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
5151 return UR_RESULT_SUCCESS;
5252}
5353
54+ ur_result_t urEnqueueUSMFill2DFallback (ur_queue_handle_t hQueue, void *pMem,
55+ size_t pitch, size_t patternSize,
56+ const void *pPattern, size_t width,
57+ size_t height,
58+ uint32_t numEventsInWaitList,
59+ const ur_event_handle_t *phEventWaitList,
60+ ur_event_handle_t *phEvent) {
61+ ur_result_t Result = getContext ()->urDdiTable .Enqueue .pfnUSMFill2D (
62+ hQueue, pMem, pitch, patternSize, pPattern, width, height,
63+ numEventsInWaitList, phEventWaitList, phEvent);
64+ if (Result == UR_RESULT_SUCCESS ||
65+ Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
66+ return Result;
67+ }
68+
69+ // fallback code
70+ auto pfnUSMFill = getContext ()->urDdiTable .Enqueue .pfnUSMFill ;
71+
72+ std::vector<ur_event_handle_t > WaitEvents (numEventsInWaitList);
73+
74+ for (size_t HeightIndex = 0 ; HeightIndex < height; HeightIndex++) {
75+ ur_event_handle_t Event = nullptr ;
76+
77+ UR_CALL (pfnUSMFill (hQueue, (void *)((char *)pMem + pitch * HeightIndex),
78+ patternSize, pPattern, width, WaitEvents.size (),
79+ WaitEvents.data (), &Event));
80+
81+ WaitEvents.push_back (Event);
82+ }
83+
84+ if (phEvent) {
85+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
86+ hQueue, WaitEvents.size (), WaitEvents.data (), phEvent));
87+ }
88+
89+ for (const auto Event : WaitEvents) {
90+ UR_CALL (getContext ()->urDdiTable .Event .pfnRelease (Event));
91+ }
92+
93+ return UR_RESULT_SUCCESS;
94+ }
95+
5496} // namespace
5597
5698// /////////////////////////////////////////////////////////////////////////////
@@ -1726,11 +1768,6 @@ ur_result_t urEnqueueUSMMemcpy2D(
17261768 {
17271769 auto pfnUSMMemcpy = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy ;
17281770
1729- std::vector<ur_event_handle_t > WaitEvents (numEventsInWaitList);
1730- for (uint32_t i = 0 ; i < numEventsInWaitList; i++) {
1731- WaitEvents[i] = phEventWaitList[i];
1732- }
1733-
17341771 for (size_t HeightIndex = 0 ; HeightIndex < height; HeightIndex++) {
17351772 ur_event_handle_t Event = nullptr ;
17361773 const auto DstOrigin =
@@ -1742,8 +1779,8 @@ ur_result_t urEnqueueUSMMemcpy2D(
17421779 width - 1 ) +
17431780 MSAN_ORIGIN_GRANULARITY;
17441781 pfnUSMMemcpy (hQueue, false , (void *)DstOrigin, (void *)SrcOrigin,
1745- SrcOriginEnd - SrcOrigin, WaitEvents. size () ,
1746- WaitEvents. data () , &Event);
1782+ SrcOriginEnd - SrcOrigin, numEventsInWaitList ,
1783+ phEventWaitList , &Event);
17471784 Events.push_back (Event);
17481785 }
17491786 }
@@ -1756,9 +1793,9 @@ ur_result_t urEnqueueUSMMemcpy2D(
17561793 const auto DstShadow = DstDI->Shadow ->MemToShadow ((uptr)pDst);
17571794 const char Pattern = 0 ;
17581795 ur_event_handle_t Event = nullptr ;
1759- UR_CALL (getContext ()-> urDdiTable . Enqueue . pfnUSMFill2D (
1760- hQueue, ( void *)DstShadow, dstPitch, 1 , &Pattern, width, height, 0 ,
1761- nullptr , &Event));
1796+ UR_CALL (urEnqueueUSMFill2DFallback (hQueue, ( void *)DstShadow, dstPitch, 1 ,
1797+ &Pattern, width, height, 0 , nullptr ,
1798+ &Event));
17621799 Events.push_back (Event);
17631800 }
17641801
@@ -1767,7 +1804,7 @@ ur_result_t urEnqueueUSMMemcpy2D(
17671804 hQueue, Events.size (), Events.data (), phEvent));
17681805 }
17691806
1770- for (const auto & E : Events)
1807+ for (const auto E : Events)
17711808 UR_CALL (getContext ()->urDdiTable .Event .pfnRelease (E));
17721809
17731810 return UR_RESULT_SUCCESS;
0 commit comments