@@ -144,6 +144,12 @@ ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) {
144
144
return Result;
145
145
}
146
146
147
+ getContext ()->logger .info (" registerDeviceGlobals" );
148
+ Result = registerDeviceGlobals (Program);
149
+ if (Result != UR_RESULT_SUCCESS) {
150
+ return Result;
151
+ }
152
+
147
153
return Result;
148
154
}
149
155
@@ -212,6 +218,53 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
212
218
return UR_RESULT_SUCCESS;
213
219
}
214
220
221
+ ur_result_t
222
+ MsanInterceptor::registerDeviceGlobals (ur_program_handle_t Program) {
223
+ std::vector<ur_device_handle_t > Devices = GetDevices (Program);
224
+ assert (Devices.size () != 0 && " No devices in registerDeviceGlobals" );
225
+ auto Context = GetContext (Program);
226
+ auto ContextInfo = getContextInfo (Context);
227
+ auto ProgramInfo = getProgramInfo (Program);
228
+ assert (ProgramInfo != nullptr && " unregistered program!" );
229
+
230
+ for (auto Device : Devices) {
231
+ ManagedQueue Queue (Context, Device);
232
+
233
+ size_t MetadataSize;
234
+ void *MetadataPtr;
235
+ auto Result =
236
+ getContext ()->urDdiTable .Program .pfnGetGlobalVariablePointer (
237
+ Device, Program, kSPIR_MsanDeviceGlobalMetadata , &MetadataSize,
238
+ &MetadataPtr);
239
+ if (Result != UR_RESULT_SUCCESS) {
240
+ getContext ()->logger .info (" No device globals" );
241
+ continue ;
242
+ }
243
+
244
+ const uint64_t NumOfDeviceGlobal =
245
+ MetadataSize / sizeof (DeviceGlobalInfo);
246
+ assert ((MetadataSize % sizeof (DeviceGlobalInfo) == 0 ) &&
247
+ " DeviceGlobal metadata size is not correct" );
248
+ std::vector<DeviceGlobalInfo> GVInfos (NumOfDeviceGlobal);
249
+ Result = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
250
+ Queue, true , &GVInfos[0 ], MetadataPtr,
251
+ sizeof (DeviceGlobalInfo) * NumOfDeviceGlobal, 0 , nullptr , nullptr );
252
+ if (Result != UR_RESULT_SUCCESS) {
253
+ getContext ()->logger .error (" Device Global[{}] Read Failed: {}" ,
254
+ kSPIR_MsanDeviceGlobalMetadata , Result);
255
+ return Result;
256
+ }
257
+
258
+ auto DeviceInfo = getMsanInterceptor ()->getDeviceInfo (Device);
259
+ for (size_t i = 0 ; i < NumOfDeviceGlobal; i++) {
260
+ const auto &GVInfo = GVInfos[i];
261
+ UR_CALL (DeviceInfo->Shadow ->EnqueuePoisonShadow (Queue, GVInfo.Addr , GVInfo.Size , 0 ));
262
+ }
263
+ }
264
+
265
+ return UR_RESULT_SUCCESS;
266
+ }
267
+
215
268
ur_result_t MsanInterceptor::insertContext (ur_context_handle_t Context,
216
269
std::shared_ptr<ContextInfo> &CI) {
217
270
std::scoped_lock<ur_shared_mutex> Guard (m_ContextMapMutex);
0 commit comments