@@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
4545 UR_CALL (DI->allocShadowMemory (Context));
4646 }
4747 CI->DeviceList .emplace_back (hDevice);
48- CI->AllocInfosMap [hDevice];
4948 }
5049 return UR_RESULT_SUCCESS;
5150}
@@ -104,6 +103,17 @@ ur_result_t urUSMDeviceAlloc(
104103 pool, size, ppMem);
105104}
106105
106+ // /////////////////////////////////////////////////////////////////////////////
107+ // / @brief Intercept function for urUSMFree
108+ __urdlllocal ur_result_t UR_APICALL urUSMFree (
109+ ur_context_handle_t hContext, // /< [in] handle of the context object
110+ void *pMem // /< [in] pointer to USM memory object
111+ ) {
112+ getContext ()->logger .debug (" ==== urUSMFree" );
113+
114+ return getMsanInterceptor ()->releaseMemory (hContext, pMem);
115+ }
116+
107117// /////////////////////////////////////////////////////////////////////////////
108118// / @brief Intercept function for urProgramCreateWithIL
109119ur_result_t urProgramCreateWithIL (
@@ -1234,6 +1244,247 @@ ur_result_t urKernelSetArgMemObj(
12341244 return UR_RESULT_SUCCESS;
12351245}
12361246
1247+ // /////////////////////////////////////////////////////////////////////////////
1248+ // / @brief Intercept function for urEnqueueUSMFill
1249+ ur_result_t UR_APICALL urEnqueueUSMFill (
1250+ ur_queue_handle_t hQueue, // /< [in] handle of the queue object
1251+ void *pMem, // /< [in][bounds(0, size)] pointer to USM memory object
1252+ size_t
1253+ patternSize, // /< [in] the size in bytes of the pattern. Must be a power of 2 and less
1254+ // /< than or equal to width.
1255+ const void
1256+ *pPattern, // /< [in] pointer with the bytes of the pattern to set.
1257+ size_t
1258+ size, // /< [in] size in bytes to be set. Must be a multiple of patternSize.
1259+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1260+ const ur_event_handle_t *
1261+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1262+ // /< events that must be complete before this command can be executed.
1263+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that this
1264+ // /< command does not wait on any event to complete.
1265+ ur_event_handle_t *
1266+ phEvent // /< [out][optional] return an event object that identifies this particular
1267+ // /< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1268+ // /< must not refer to an element of the phEventWaitList array.
1269+ ) {
1270+ auto pfnUSMFill = getContext ()->urDdiTable .Enqueue .pfnUSMFill ;
1271+ getContext ()->logger .debug (" ==== urEnqueueUSMFill" );
1272+
1273+ ur_event_handle_t hEvents[2 ] = {};
1274+ UR_CALL (pfnUSMFill (hQueue, pMem, patternSize, pPattern, size,
1275+ numEventsInWaitList, phEventWaitList, &hEvents[0 ]));
1276+
1277+ const auto Mem = (uptr)pMem;
1278+ auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
1279+ if (MemInfoItOp) {
1280+ auto MemInfo = (*MemInfoItOp)->second ;
1281+
1282+ const auto &DeviceInfo =
1283+ getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1284+ const auto MemShadow = DeviceInfo->Shadow ->MemToShadow (Mem);
1285+
1286+ UR_CALL (EnqueueUSMBlockingSet (hQueue, (void *)MemShadow, 0 , size, 0 ,
1287+ nullptr , &hEvents[1 ]));
1288+ }
1289+
1290+ if (phEvent) {
1291+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1292+ hQueue, 2 , hEvents, phEvent));
1293+ }
1294+
1295+ return UR_RESULT_SUCCESS;
1296+ }
1297+
1298+ // /////////////////////////////////////////////////////////////////////////////
1299+ // / @brief Intercept function for urEnqueueUSMMemcpy
1300+ ur_result_t UR_APICALL urEnqueueUSMMemcpy (
1301+ ur_queue_handle_t hQueue, // /< [in] handle of the queue object
1302+ bool blocking, // /< [in] blocking or non-blocking copy
1303+ void *
1304+ pDst, // /< [in][bounds(0, size)] pointer to the destination USM memory object
1305+ const void *
1306+ pSrc, // /< [in][bounds(0, size)] pointer to the source USM memory object
1307+ size_t size, // /< [in] size in bytes to be copied
1308+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1309+ const ur_event_handle_t *
1310+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1311+ // /< events that must be complete before this command can be executed.
1312+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that this
1313+ // /< command does not wait on any event to complete.
1314+ ur_event_handle_t *
1315+ phEvent // /< [out][optional] return an event object that identifies this particular
1316+ // /< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1317+ // /< must not refer to an element of the phEventWaitList array.
1318+ ) {
1319+ auto pfnUSMMemcpy = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy ;
1320+ getContext ()->logger .debug (" ==== pfnUSMMemcpy" );
1321+
1322+ ur_event_handle_t hEvents[2 ] = {};
1323+ UR_CALL (pfnUSMMemcpy (hQueue, blocking, pDst, pSrc, size,
1324+ numEventsInWaitList, phEventWaitList, &hEvents[0 ]));
1325+
1326+ const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1327+ auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
1328+ auto DstInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Dst);
1329+
1330+ if (SrcInfoItOp && DstInfoItOp) {
1331+ auto SrcInfo = (*SrcInfoItOp)->second ;
1332+ auto DstInfo = (*DstInfoItOp)->second ;
1333+
1334+ const auto &DeviceInfo =
1335+ getMsanInterceptor ()->getDeviceInfo (SrcInfo->Device );
1336+ const auto SrcShadow = DeviceInfo->Shadow ->MemToShadow (Src);
1337+ const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1338+
1339+ UR_CALL (pfnUSMMemcpy (hQueue, blocking, (void *)DstShadow,
1340+ (void *)SrcShadow, size, 0 , nullptr , &hEvents[1 ]));
1341+ } else if (DstInfoItOp) {
1342+ auto DstInfo = (*DstInfoItOp)->second ;
1343+
1344+ const auto &DeviceInfo =
1345+ getMsanInterceptor ()->getDeviceInfo (DstInfo->Device );
1346+ auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1347+
1348+ UR_CALL (EnqueueUSMBlockingSet (hQueue, (void *)DstShadow, 0 , size, 0 ,
1349+ nullptr , &hEvents[1 ]));
1350+ }
1351+
1352+ if (phEvent) {
1353+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1354+ hQueue, 2 , hEvents, phEvent));
1355+ }
1356+
1357+ return UR_RESULT_SUCCESS;
1358+ }
1359+
1360+ // /////////////////////////////////////////////////////////////////////////////
1361+ // / @brief Intercept function for urEnqueueUSMFill2D
1362+ ur_result_t UR_APICALL urEnqueueUSMFill2D (
1363+ ur_queue_handle_t hQueue, // /< [in] handle of the queue to submit to.
1364+ void *
1365+ pMem, // /< [in][bounds(0, pitch * height)] pointer to memory to be filled.
1366+ size_t
1367+ pitch, // /< [in] the total width of the destination memory including padding.
1368+ size_t
1369+ patternSize, // /< [in] the size in bytes of the pattern. Must be a power of 2 and less
1370+ // /< than or equal to width.
1371+ const void
1372+ *pPattern, // /< [in] pointer with the bytes of the pattern to set.
1373+ size_t
1374+ width, // /< [in] the width in bytes of each row to fill. Must be a multiple of
1375+ // /< patternSize.
1376+ size_t height, // /< [in] the height of the columns to fill.
1377+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1378+ const ur_event_handle_t *
1379+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1380+ // /< events that must be complete before the kernel execution.
1381+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1382+ ur_event_handle_t *
1383+ phEvent // /< [out][optional] return an event object that identifies this particular
1384+ // /< kernel execution instance. If phEventWaitList and phEvent are not
1385+ // /< NULL, phEvent must not refer to an element of the phEventWaitList array.
1386+ ) {
1387+ auto pfnUSMFill2D = getContext ()->urDdiTable .Enqueue .pfnUSMFill2D ;
1388+ getContext ()->logger .debug (" ==== urEnqueueUSMFill2D" );
1389+
1390+ ur_event_handle_t hEvents[2 ] = {};
1391+ UR_CALL (pfnUSMFill2D (hQueue, pMem, pitch, patternSize, pPattern, width,
1392+ height, numEventsInWaitList, phEventWaitList,
1393+ &hEvents[0 ]));
1394+
1395+ const auto Mem = (uptr)pMem;
1396+ auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
1397+ if (MemInfoItOp) {
1398+ auto MemInfo = (*MemInfoItOp)->second ;
1399+
1400+ const auto &DeviceInfo =
1401+ getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1402+ const auto MemShadow = DeviceInfo->Shadow ->MemToShadow (Mem);
1403+
1404+ const char Pattern = 0 ;
1405+ UR_CALL (pfnUSMFill2D (hQueue, (void *)MemShadow, pitch, 1 , &Pattern,
1406+ width, height, 0 , nullptr , &hEvents[1 ]));
1407+ }
1408+
1409+ if (phEvent) {
1410+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1411+ hQueue, 2 , hEvents, phEvent));
1412+ }
1413+
1414+ return UR_RESULT_SUCCESS;
1415+ }
1416+
1417+ // /////////////////////////////////////////////////////////////////////////////
1418+ // / @brief Intercept function for urEnqueueUSMMemcpy2D
1419+ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D (
1420+ ur_queue_handle_t hQueue, // /< [in] handle of the queue to submit to.
1421+ bool blocking, // /< [in] indicates if this operation should block the host.
1422+ void *
1423+ pDst, // /< [in][bounds(0, dstPitch * height)] pointer to memory where data will
1424+ // /< be copied.
1425+ size_t
1426+ dstPitch, // /< [in] the total width of the source memory including padding.
1427+ const void *
1428+ pSrc, // /< [in][bounds(0, srcPitch * height)] pointer to memory to be copied.
1429+ size_t
1430+ srcPitch, // /< [in] the total width of the source memory including padding.
1431+ size_t width, // /< [in] the width in bytes of each row to be copied.
1432+ size_t height, // /< [in] the height of columns to be copied.
1433+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1434+ const ur_event_handle_t *
1435+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1436+ // /< events that must be complete before the kernel execution.
1437+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1438+ ur_event_handle_t *
1439+ phEvent // /< [out][optional] return an event object that identifies this particular
1440+ // /< kernel execution instance. If phEventWaitList and phEvent are not
1441+ // /< NULL, phEvent must not refer to an element of the phEventWaitList array.
1442+ ) {
1443+ auto pfnUSMMemcpy2D = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy2D ;
1444+ getContext ()->logger .debug (" ==== pfnUSMMemcpy2D" );
1445+
1446+ ur_event_handle_t hEvents[2 ] = {};
1447+ UR_CALL (pfnUSMMemcpy2D (hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
1448+ width, height, numEventsInWaitList, phEventWaitList,
1449+ &hEvents[0 ]));
1450+
1451+ const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1452+ auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
1453+ auto DstInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Dst);
1454+
1455+ if (SrcInfoItOp && DstInfoItOp) {
1456+ auto SrcInfo = (*SrcInfoItOp)->second ;
1457+ auto DstInfo = (*DstInfoItOp)->second ;
1458+
1459+ const auto &DeviceInfo =
1460+ getMsanInterceptor ()->getDeviceInfo (SrcInfo->Device );
1461+ const auto SrcShadow = DeviceInfo->Shadow ->MemToShadow (Src);
1462+ const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1463+
1464+ UR_CALL (pfnUSMMemcpy2D (hQueue, blocking, (void *)DstShadow, dstPitch,
1465+ (void *)SrcShadow, srcPitch, width, height, 0 ,
1466+ nullptr , &hEvents[1 ]));
1467+ } else if (DstInfoItOp) {
1468+ auto DstInfo = (*DstInfoItOp)->second ;
1469+
1470+ const auto &DeviceInfo =
1471+ getMsanInterceptor ()->getDeviceInfo (DstInfo->Device );
1472+ const auto DstShadow = DeviceInfo->Shadow ->MemToShadow (Dst);
1473+
1474+ const char Pattern = 0 ;
1475+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnUSMFill2D (
1476+ hQueue, (void *)DstShadow, dstPitch, 1 , &Pattern, width, height, 0 ,
1477+ nullptr , &hEvents[1 ]));
1478+ }
1479+
1480+ if (phEvent) {
1481+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
1482+ hQueue, 2 , hEvents, phEvent));
1483+ }
1484+
1485+ return UR_RESULT_SUCCESS;
1486+ }
1487+
12371488// /////////////////////////////////////////////////////////////////////////////
12381489// / @brief Exported function for filling application's Global table
12391490// / with current process' addresses
@@ -1391,6 +1642,10 @@ ur_result_t urGetEnqueueProcAddrTable(
13911642 pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap;
13921643 pDdiTable->pfnKernelLaunch =
13931644 ur_sanitizer_layer::msan::urEnqueueKernelLaunch;
1645+ pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill;
1646+ pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy;
1647+ pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D;
1648+ pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D;
13941649
13951650 return result;
13961651}
@@ -1408,6 +1663,7 @@ ur_result_t urGetUSMProcAddrTable(
14081663 ur_result_t result = UR_RESULT_SUCCESS;
14091664
14101665 pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
1666+ pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;
14111667
14121668 return result;
14131669}
0 commit comments