@@ -110,6 +110,17 @@ ur_result_t urUSMDeviceAlloc(
110110 pool, size, ppMem);
111111}
112112
113+ // /////////////////////////////////////////////////////////////////////////////
114+ // / @brief Intercept function for urUSMFree
115+ __urdlllocal ur_result_t UR_APICALL urUSMFree (
116+ ur_context_handle_t hContext, // /< [in] handle of the context object
117+ void *pMem // /< [in] pointer to USM memory object
118+ ) {
119+ getContext ()->logger .debug (" ==== urUSMFree" );
120+
121+ return getMsanInterceptor ()->releaseMemory (hContext, pMem);
122+ }
123+
113124// /////////////////////////////////////////////////////////////////////////////
114125// / @brief Intercept function for urProgramCreateWithIL
115126ur_result_t urProgramCreateWithIL (
@@ -1271,6 +1282,176 @@ ur_result_t urKernelSetArgMemObj(
12711282 return UR_RESULT_SUCCESS;
12721283}
12731284
1285+ // /////////////////////////////////////////////////////////////////////////////
1286+ // / @brief Intercept function for urEnqueueUSMFill
1287+ ur_result_t UR_APICALL urEnqueueUSMFill (
1288+ ur_queue_handle_t hQueue, // /< [in] handle of the queue object
1289+ void *pMem, // /< [in][bounds(0, size)] pointer to USM memory object
1290+ size_t
1291+ patternSize, // /< [in] the size in bytes of the pattern. Must be a power of 2 and less
1292+ // /< than or equal to width.
1293+ const void
1294+ *pPattern, // /< [in] pointer with the bytes of the pattern to set.
1295+ size_t
1296+ size, // /< [in] size in bytes to be set. Must be a multiple of patternSize.
1297+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1298+ const ur_event_handle_t *
1299+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1300+ // /< events that must be complete before this command can be executed.
1301+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that this
1302+ // /< command does not wait on any event to complete.
1303+ ur_event_handle_t *
1304+ phEvent // /< [out][optional] return an event object that identifies this particular
1305+ // /< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1306+ // /< must not refer to an element of the phEventWaitList array.
1307+ ) {
1308+ auto pfnUSMFill = getContext ()->urDdiTable .Enqueue .pfnUSMFill ;
1309+
1310+ getContext ()->logger .debug (" ==== urEnqueueUSMFill" );
1311+
1312+ auto Mem = (uptr)pMem;
1313+ auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
1314+ if (MemInfoItOp) {
1315+ auto MemInfo = (*MemInfoItOp)->second ;
1316+
1317+ const auto &DeviceInfo =
1318+ getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1319+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (
1320+ hQueue, Mem, size, 0 , numEventsInWaitList, phEventWaitList,
1321+ phEvent));
1322+ }
1323+
1324+ return pfnUSMFill (hQueue, pMem, patternSize, pPattern, size,
1325+ numEventsInWaitList, phEventWaitList, phEvent);
1326+ }
1327+
1328+ // /////////////////////////////////////////////////////////////////////////////
1329+ // / @brief Intercept function for urEnqueueUSMMemcpy
1330+ ur_result_t UR_APICALL urEnqueueUSMMemcpy (
1331+ ur_queue_handle_t hQueue, // /< [in] handle of the queue object
1332+ bool blocking, // /< [in] blocking or non-blocking copy
1333+ void *
1334+ pDst, // /< [in][bounds(0, size)] pointer to the destination USM memory object
1335+ const void *
1336+ pSrc, // /< [in][bounds(0, size)] pointer to the source USM memory object
1337+ size_t size, // /< [in] size in bytes to be copied
1338+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1339+ const ur_event_handle_t *
1340+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1341+ // /< events that must be complete before this command can be executed.
1342+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that this
1343+ // /< command does not wait on any event to complete.
1344+ ur_event_handle_t *
1345+ phEvent // /< [out][optional] return an event object that identifies this particular
1346+ // /< command instance. If phEventWaitList and phEvent are not NULL, phEvent
1347+ // /< must not refer to an element of the phEventWaitList array.
1348+ ) {
1349+ auto pfnUSMMemcpy = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy ;
1350+ getContext ()->logger .debug (" ==== pfnUSMMemcpy" );
1351+
1352+ auto Src = (uptr)pSrc, Dst = (uptr)pDst;
1353+ auto SrcInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Src);
1354+ auto DstInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Dst);
1355+
1356+ if (SrcInfoItOp && DstInfoItOp) {
1357+ auto SrcInfo = (*SrcInfoItOp)->second ;
1358+ auto DstInfo = (*DstInfoItOp)->second ;
1359+
1360+ const auto &DeviceInfo =
1361+ getMsanInterceptor ()->getDeviceInfo (SrcInfo->Device );
1362+ UR_CALL (DeviceInfo->Shadow ->EnqueueCopyShadow (
1363+ hQueue, blocking, Dst, Src, size, numEventsInWaitList,
1364+ phEventWaitList, phEvent));
1365+ } else if (DstInfoItOp) {
1366+ auto DstInfo = (*DstInfoItOp)->second ;
1367+
1368+ const auto &DeviceInfo =
1369+ getMsanInterceptor ()->getDeviceInfo (DstInfo->Device );
1370+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (
1371+ hQueue, Dst, size, 0 , numEventsInWaitList, phEventWaitList,
1372+ phEvent));
1373+ }
1374+
1375+ return pfnUSMMemcpy (hQueue, blocking, pDst, pSrc, size, numEventsInWaitList,
1376+ phEventWaitList, phEvent);
1377+ }
1378+
1379+ // /////////////////////////////////////////////////////////////////////////////
1380+ // / @brief Intercept function for urEnqueueUSMFill2D
1381+ ur_result_t UR_APICALL urEnqueueUSMFill2D (
1382+ ur_queue_handle_t hQueue, // /< [in] handle of the queue to submit to.
1383+ void *
1384+ pMem, // /< [in][bounds(0, pitch * height)] pointer to memory to be filled.
1385+ size_t
1386+ pitch, // /< [in] the total width of the destination memory including padding.
1387+ size_t
1388+ patternSize, // /< [in] the size in bytes of the pattern. Must be a power of 2 and less
1389+ // /< than or equal to width.
1390+ const void
1391+ *pPattern, // /< [in] pointer with the bytes of the pattern to set.
1392+ size_t
1393+ width, // /< [in] the width in bytes of each row to fill. Must be a multiple of
1394+ // /< patternSize.
1395+ size_t height, // /< [in] the height of the columns to fill.
1396+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1397+ const ur_event_handle_t *
1398+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1399+ // /< events that must be complete before the kernel execution.
1400+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1401+ ur_event_handle_t *
1402+ phEvent // /< [out][optional] return an event object that identifies this particular
1403+ // /< kernel execution instance. If phEventWaitList and phEvent are not
1404+ // /< NULL, phEvent must not refer to an element of the phEventWaitList array.
1405+ ) {
1406+ auto pfnUSMFill2D = getContext ()->urDdiTable .Enqueue .pfnUSMFill2D ;
1407+ getContext ()->logger .debug (" ==== urEnqueueUSMFill2D" );
1408+
1409+ auto Mem = (uptr)pMem;
1410+ auto MemInfoItOp = getMsanInterceptor ()->findAllocInfoByAddress (Mem);
1411+ if (MemInfoItOp) {
1412+ auto MemInfo = (*MemInfoItOp)->second ;
1413+
1414+ const auto &DeviceInfo =
1415+ getMsanInterceptor ()->getDeviceInfo (MemInfo->Device );
1416+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (
1417+ hQueue, Mem, width * height, 0 , numEventsInWaitList,
1418+ phEventWaitList, phEvent));
1419+ }
1420+
1421+ return pfnUSMFill2D (hQueue, pMem, pitch, patternSize, pPattern, width,
1422+ height, numEventsInWaitList, phEventWaitList, phEvent);
1423+ }
1424+
1425+ // /////////////////////////////////////////////////////////////////////////////
1426+ // / @brief Intercept function for urEnqueueUSMMemcpy2D
1427+ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D (
1428+ ur_queue_handle_t hQueue, // /< [in] handle of the queue to submit to.
1429+ bool blocking, // /< [in] indicates if this operation should block the host.
1430+ void *
1431+ pDst, // /< [in][bounds(0, dstPitch * height)] pointer to memory where data will
1432+ // /< be copied.
1433+ size_t
1434+ dstPitch, // /< [in] the total width of the source memory including padding.
1435+ const void *
1436+ pSrc, // /< [in][bounds(0, srcPitch * height)] pointer to memory to be copied.
1437+ size_t
1438+ srcPitch, // /< [in] the total width of the source memory including padding.
1439+ size_t width, // /< [in] the width in bytes of each row to be copied.
1440+ size_t height, // /< [in] the height of columns to be copied.
1441+ uint32_t numEventsInWaitList, // /< [in] size of the event wait list
1442+ const ur_event_handle_t *
1443+ phEventWaitList, // /< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1444+ // /< events that must be complete before the kernel execution.
1445+ // /< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event.
1446+ ur_event_handle_t *
1447+ phEvent // /< [out][optional] return an event object that identifies this particular
1448+ // /< kernel execution instance. If phEventWaitList and phEvent are not
1449+ // /< NULL, phEvent must not refer to an element of the phEventWaitList array.
1450+ ) {
1451+ ur_result_t result = UR_RESULT_SUCCESS;
1452+ return result;
1453+ }
1454+
12741455// /////////////////////////////////////////////////////////////////////////////
12751456// / @brief Exported function for filling application's Global table
12761457// / with current process' addresses
@@ -1429,6 +1610,10 @@ ur_result_t urGetEnqueueProcAddrTable(
14291610 pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap;
14301611 pDdiTable->pfnKernelLaunch =
14311612 ur_sanitizer_layer::msan::urEnqueueKernelLaunch;
1613+ pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill;
1614+ pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy;
1615+ pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D;
1616+ pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D;
14321617
14331618 return result;
14341619}
@@ -1446,6 +1631,7 @@ ur_result_t urGetUSMProcAddrTable(
14461631 ur_result_t result = UR_RESULT_SUCCESS;
14471632
14481633 pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc;
1634+ pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree;
14491635
14501636 return result;
14511637}
0 commit comments