@@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
4545 UR_CALL (DI->allocShadowMemory (Context));
4646 }
4747 CI->DeviceList .emplace_back (hDevice);
48- CI->AllocInfosMap [hDevice];
4948 }
5049 return UR_RESULT_SUCCESS;
5150}
@@ -517,6 +516,12 @@ ur_result_t urMemBufferCreate(
517516 UR_CALL (pMemBuffer->getHandle (hDevice, Handle));
518517 UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
519518 InternalQueue, true , Handle, Host, size, 0 , nullptr , nullptr ));
519+
520+ // Update shadow memory
521+ std::shared_ptr<DeviceInfo> DeviceInfo =
522+ getMsanInterceptor ()->getDeviceInfo (hDevice);
523+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (
524+ InternalQueue, (uptr)Handle, size, 0 ));
520525 }
521526 }
522527
@@ -732,10 +737,25 @@ ur_result_t urEnqueueMemBufferWrite(
732737 if (auto MemBuffer = getMsanInterceptor ()->getMemBuffer (hBuffer)) {
733738 ur_device_handle_t Device = GetDevice (hQueue);
734739 char *pDst = nullptr ;
740+ ur_event_handle_t Events[2 ];
735741 UR_CALL (MemBuffer->getHandle (Device, pDst));
736742 UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
737743 hQueue, blockingWrite, pDst + offset, pSrc, size,
738- numEventsInWaitList, phEventWaitList, phEvent));
744+ numEventsInWaitList, phEventWaitList, &Events[0 ]));
745+
746+ // Update shadow memory
747+ std::shared_ptr<DeviceInfo> DeviceInfo =
748+ getMsanInterceptor ()->getDeviceInfo (Device);
749+ const char Val = 0 ;
750+ uptr ShadowAddr = DeviceInfo->Shadow ->MemToShadow ((uptr)pDst + offset);
751+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill (
752+ hQueue, (void *)ShadowAddr, 1 , &Val, size, numEventsInWaitList,
753+ phEventWaitList, &Events[1 ]));
754+
755+ if (phEvent) {
756+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
757+ hQueue, 2 , Events, phEvent));
758+ }
739759 } else {
740760 UR_CALL (pfnMemBufferWrite (hQueue, hBuffer, blockingWrite, offset, size,
741761 pSrc, numEventsInWaitList, phEventWaitList,
@@ -895,15 +915,32 @@ ur_result_t urEnqueueMemBufferCopy(
895915
896916 if (SrcBuffer && DstBuffer) {
897917 ur_device_handle_t Device = GetDevice (hQueue);
918+ std::shared_ptr<DeviceInfo> DeviceInfo =
919+ getMsanInterceptor ()->getDeviceInfo (Device);
898920 char *SrcHandle = nullptr ;
899921 UR_CALL (SrcBuffer->getHandle (Device, SrcHandle));
900922
901923 char *DstHandle = nullptr ;
902924 UR_CALL (DstBuffer->getHandle (Device, DstHandle));
903925
926+ ur_event_handle_t Events[2 ];
904927 UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
905928 hQueue, false , DstHandle + dstOffset, SrcHandle + srcOffset, size,
906- numEventsInWaitList, phEventWaitList, phEvent));
929+ numEventsInWaitList, phEventWaitList, &Events[0 ]));
930+
931+ // Update shadow memory
932+ uptr DstShadowAddr =
933+ DeviceInfo->Shadow ->MemToShadow ((uptr)DstHandle + dstOffset);
934+ uptr SrcShadowAddr =
935+ DeviceInfo->Shadow ->MemToShadow ((uptr)SrcHandle + srcOffset);
936+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
937+ hQueue, false , (void *)DstShadowAddr, (void *)SrcShadowAddr, size,
938+ numEventsInWaitList, phEventWaitList, &Events[1 ]));
939+
940+ if (phEvent) {
941+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
942+ hQueue, 2 , Events, phEvent));
943+ }
907944 } else {
908945 UR_CALL (pfnMemBufferCopy (hQueue, hBufferSrc, hBufferDst, srcOffset,
909946 dstOffset, size, numEventsInWaitList,
@@ -1002,11 +1039,27 @@ ur_result_t urEnqueueMemBufferFill(
10021039
10031040 if (auto MemBuffer = getMsanInterceptor ()->getMemBuffer (hBuffer)) {
10041041 char *Handle = nullptr ;
1042+ ur_event_handle_t Events[2 ];
10051043 ur_device_handle_t Device = GetDevice (hQueue);
10061044 UR_CALL (MemBuffer->getHandle (Device, Handle));
10071045 UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill (
10081046 hQueue, Handle + offset, patternSize, pPattern, size,
1009- numEventsInWaitList, phEventWaitList, phEvent));
1047+ numEventsInWaitList, phEventWaitList, &Events[0 ]));
1048+
1049+ // Update shadow memory
1050+ std::shared_ptr<DeviceInfo> DeviceInfo =
1051+ getMsanInterceptor ()->getDeviceInfo (Device);
1052+ const char Val = 0 ;
1053+ uptr ShadowAddr =
1054+ DeviceInfo->Shadow ->MemToShadow ((uptr)Handle + offset);
1055+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill (
1056+ hQueue, (void *)ShadowAddr, 1 , &Val, size, numEventsInWaitList,
1057+ phEventWaitList, &Events[1 ]));
1058+
1059+ if (phEvent) {
1060+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1061+ hQueue, 2 , Events, phEvent));
1062+ }
10101063 } else {
10111064 UR_CALL (pfnMemBufferFill (hQueue, hBuffer, pPattern, patternSize, offset,
10121065 size, numEventsInWaitList, phEventWaitList,
0 commit comments