@@ -144,6 +144,12 @@ ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) {
144144 return Result;
145145 }
146146
147+ getContext ()->logger .info (" registerDeviceGlobals" );
148+ Result = registerDeviceGlobals (Program);
149+ if (Result != UR_RESULT_SUCCESS) {
150+ return Result;
151+ }
152+
147153 return Result;
148154}
149155
@@ -212,6 +218,53 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
212218 return UR_RESULT_SUCCESS;
213219}
214220
221+ ur_result_t
222+ MsanInterceptor::registerDeviceGlobals (ur_program_handle_t Program) {
223+ std::vector<ur_device_handle_t > Devices = GetDevices (Program);
224+ assert (Devices.size () != 0 && " No devices in registerDeviceGlobals" );
225+ auto Context = GetContext (Program);
226+ auto ContextInfo = getContextInfo (Context);
227+ auto ProgramInfo = getProgramInfo (Program);
228+ assert (ProgramInfo != nullptr && " unregistered program!" );
229+
230+ for (auto Device : Devices) {
231+ ManagedQueue Queue (Context, Device);
232+
233+ size_t MetadataSize;
234+ void *MetadataPtr;
235+ auto Result =
236+ getContext ()->urDdiTable .Program .pfnGetGlobalVariablePointer (
237+ Device, Program, kSPIR_MsanDeviceGlobalMetadata , &MetadataSize,
238+ &MetadataPtr);
239+ if (Result != UR_RESULT_SUCCESS) {
240+ getContext ()->logger .info (" No device globals" );
241+ continue ;
242+ }
243+
244+ const uint64_t NumOfDeviceGlobal =
245+ MetadataSize / sizeof (DeviceGlobalInfo);
246+ assert ((MetadataSize % sizeof (DeviceGlobalInfo) == 0 ) &&
247+ " DeviceGlobal metadata size is not correct" );
248+ std::vector<DeviceGlobalInfo> GVInfos (NumOfDeviceGlobal);
249+ Result = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
250+ Queue, true , &GVInfos[0 ], MetadataPtr,
251+ sizeof (DeviceGlobalInfo) * NumOfDeviceGlobal, 0 , nullptr , nullptr );
252+ if (Result != UR_RESULT_SUCCESS) {
253+ getContext ()->logger .error (" Device Global[{}] Read Failed: {}" ,
254+ kSPIR_MsanDeviceGlobalMetadata , Result);
255+ return Result;
256+ }
257+
258+ auto DeviceInfo = getMsanInterceptor ()->getDeviceInfo (Device);
259+ for (size_t i = 0 ; i < NumOfDeviceGlobal; i++) {
260+ const auto &GVInfo = GVInfos[i];
261+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (Queue, GVInfo.Addr , GVInfo.Size , 0 ));
262+ }
263+ }
264+
265+ return UR_RESULT_SUCCESS;
266+ }
267+
215268ur_result_t MsanInterceptor::insertContext (ur_context_handle_t Context,
216269 std::shared_ptr<ContextInfo> &CI) {
217270 std::scoped_lock<ur_shared_mutex> Guard (m_ContextMapMutex);
0 commit comments