Skip to content

Commit 9283535

Browse files
committed
fix cpu
1 parent 2469522 commit 9283535

File tree

1 file changed

+44
-19
lines changed

1 file changed

+44
-19
lines changed

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,8 +1307,9 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
13071307
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
13081308
getContext()->logger.debug("==== urEnqueueUSMFill");
13091309

1310+
ur_event_handle_t hEvent = nullptr;
13101311
UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
1311-
numEventsInWaitList, phEventWaitList, phEvent));
1312+
numEventsInWaitList, phEventWaitList, &hEvent));
13121313

13131314
const auto Mem = (uptr)pMem;
13141315
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
@@ -1319,8 +1320,13 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
13191320
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
13201321
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);
13211322

1322-
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size,
1323-
phEvent ? 1 : 0, phEvent, phEvent));
1323+
const ur_event_handle_t hEventWait = hEvent;
1324+
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 1,
1325+
&hEventWait, &hEvent));
1326+
}
1327+
1328+
if (phEvent) {
1329+
*phEvent = hEvent;
13241330
}
13251331

13261332
return UR_RESULT_SUCCESS;
@@ -1350,8 +1356,9 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13501356
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
13511357
getContext()->logger.debug("==== pfnUSMMemcpy");
13521358

1359+
ur_event_handle_t hEvent = nullptr;
13531360
UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
1354-
numEventsInWaitList, phEventWaitList, phEvent));
1361+
numEventsInWaitList, phEventWaitList, &hEvent));
13551362

13561363
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
13571364
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
@@ -1366,18 +1373,23 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13661373
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
13671374
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
13681375

1376+
const ur_event_handle_t hEventWait = hEvent;
13691377
UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
1370-
(void *)SrcShadow, size, phEvent ? 1 : 0, phEvent,
1371-
phEvent));
1378+
(void *)SrcShadow, size, 1, &hEventWait, &hEvent));
13721379
} else if (DstInfoItOp) {
13731380
auto DstInfo = (*DstInfoItOp)->second;
13741381

13751382
const auto &DeviceInfo =
13761383
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
13771384
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
13781385

1379-
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size,
1380-
phEvent ? 1 : 0, phEvent, phEvent));
1386+
const ur_event_handle_t hEventWait = hEvent;
1387+
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 1,
1388+
&hEventWait, &hEvent));
1389+
}
1390+
1391+
if (phEvent) {
1392+
*phEvent = hEvent;
13811393
}
13821394

13831395
return UR_RESULT_SUCCESS;
@@ -1413,9 +1425,10 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
14131425
auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
14141426
getContext()->logger.debug("==== urEnqueueUSMFill2D");
14151427

1428+
ur_event_handle_t hEvent = nullptr;
14161429
UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
14171430
height, numEventsInWaitList, phEventWaitList,
1418-
phEvent));
1431+
&hEvent));
14191432

14201433
const auto Mem = (uptr)pMem;
14211434
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
@@ -1427,8 +1440,13 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
14271440
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);
14281441

14291442
const char Pattern = 0;
1443+
const ur_event_handle_t hEventWait = hEvent;
14301444
UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
1431-
width, height, phEvent ? 1 : 0, phEvent, phEvent));
1445+
width, height, 1, &hEventWait, &hEvent));
1446+
}
1447+
1448+
if (phEvent) {
1449+
*phEvent = hEvent;
14321450
}
14331451

14341452
return UR_RESULT_SUCCESS;
@@ -1463,11 +1481,12 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
14631481
auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
14641482
getContext()->logger.debug("==== pfnUSMMemcpy2D");
14651483

1484+
ur_event_handle_t hEvent = nullptr;
14661485
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
14671486
width, height, numEventsInWaitList, phEventWaitList,
1468-
phEvent));
1487+
&hEvent));
14691488

1470-
auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1489+
const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
14711490
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
14721491
auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst);
14731492

@@ -1477,23 +1496,29 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
14771496

14781497
const auto &DeviceInfo =
14791498
getMsanInterceptor()->getDeviceInfo(SrcInfo->Device);
1480-
auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
1481-
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
1499+
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
1500+
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
14821501

1502+
const ur_event_handle_t hEventWait = hEvent;
14831503
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
1484-
(void *)SrcShadow, srcPitch, width, height,
1485-
phEvent ? 1 : 0, phEvent, phEvent));
1504+
(void *)SrcShadow, srcPitch, width, height, 1,
1505+
&hEventWait, &hEvent));
14861506
} else if (DstInfoItOp) {
14871507
auto DstInfo = (*DstInfoItOp)->second;
14881508

14891509
const auto &DeviceInfo =
14901510
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
1491-
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
1511+
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);
14921512

14931513
const char Pattern = 0;
1514+
const ur_event_handle_t hEventWait = hEvent;
14941515
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
1495-
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height,
1496-
phEvent ? 1 : 0, phEvent, phEvent));
1516+
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 1,
1517+
&hEventWait, &hEvent));
1518+
}
1519+
1520+
if (phEvent) {
1521+
*phEvent = hEvent;
14971522
}
14981523

14991524
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)