@@ -46,18 +46,36 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
4646 ur_device_handle_t Device,
4747 const ur_usm_desc_t *Properties,
4848 ur_usm_pool_handle_t Pool,
49- size_t Size, void **ResultPtr) {
49+ size_t Size, AllocType Type,
50+ void **ResultPtr) {
5051
5152 auto ContextInfo = getContextInfo (Context);
52- std::shared_ptr<DeviceInfo> DeviceInfo = getDeviceInfo (Device);
53+ std::shared_ptr<DeviceInfo> DeviceInfo =
54+ Device ? getDeviceInfo (Device) : nullptr ;
5355
5456 void *Allocated = nullptr ;
5557
56- UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
57- Context, Device, Properties, Pool, Size, &Allocated));
58+ if (Type == AllocType::DEVICE_USM) {
59+ UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
60+ Context, Device, Properties, Pool, Size, &Allocated));
61+ } else if (Type == AllocType::HOST_USM) {
62+ UR_CALL (getContext ()->urDdiTable .USM .pfnHostAlloc (
63+ Context, Properties, Pool, Size, &Allocated));
64+ } else if (Type == AllocType::SHARED_USM) {
65+ UR_CALL (getContext ()->urDdiTable .USM .pfnSharedAlloc (
66+ Context, Device, Properties, Pool, Size, &Allocated));
67+ }
5868
5969 *ResultPtr = Allocated;
6070
71+ ContextInfo->MaxAllocatedSize =
72+ std::max (ContextInfo->MaxAllocatedSize , Size);
73+
74+ // For host/shared usm, we only record the alloc size.
75+ if (Type != AllocType::DEVICE_USM) {
76+ return UR_RESULT_SUCCESS;
77+ }
78+
6179 auto AI =
6280 std::make_shared<MsanAllocInfo>(MsanAllocInfo{(uptr)Allocated,
6381 Size,
@@ -432,10 +450,14 @@ ur_result_t MsanInterceptor::prepareLaunch(
432450 }
433451
434452 // Set LaunchInfo
453+ auto ContextInfo = getContextInfo (LaunchInfo.Context );
435454 LaunchInfo.Data ->GlobalShadowOffset = DeviceInfo->Shadow ->ShadowBegin ;
436455 LaunchInfo.Data ->GlobalShadowOffsetEnd = DeviceInfo->Shadow ->ShadowEnd ;
437456 LaunchInfo.Data ->DeviceTy = DeviceInfo->Type ;
438457 LaunchInfo.Data ->Debug = getOptions ().Debug ? 1 : 0 ;
458+ UR_CALL (getContext ()->urDdiTable .USM .pfnDeviceAlloc (
459+ ContextInfo->Handle , DeviceInfo->Handle , nullptr , nullptr ,
460+ ContextInfo->MaxAllocatedSize , &LaunchInfo.Data ->CleanShadow ));
439461
440462 getContext ()->logger .info (
441463 " launch_info {} (GlobalShadow={}, Device={}, Debug={})" ,
@@ -518,6 +540,11 @@ ur_result_t USMLaunchInfo::initialize() {
518540USMLaunchInfo::~USMLaunchInfo () {
519541 [[maybe_unused]] ur_result_t Result;
520542 if (Data) {
543+ if (Data->CleanShadow ) {
544+ Result = getContext ()->urDdiTable .USM .pfnFree (Context,
545+ Data->CleanShadow );
546+ assert (Result == UR_RESULT_SUCCESS);
547+ }
521548 Result = getContext ()->urDdiTable .USM .pfnFree (Context, (void *)Data);
522549 assert (Result == UR_RESULT_SUCCESS);
523550 }
0 commit comments