@@ -217,6 +217,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
217
217
USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
218
218
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
219
219
workDim);
220
+ UR_CALL (LaunchInfo.initialize ());
220
221
221
222
UR_CALL (context.interceptor ->preLaunchKernel (hKernel, hQueue, LaunchInfo));
222
223
@@ -317,6 +318,90 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
317
318
return result;
318
319
}
319
320
321
+ // /////////////////////////////////////////////////////////////////////////////
322
+ // / @brief Intercept function for urKernelCreate
323
+ __urdlllocal ur_result_t UR_APICALL urKernelCreate (
324
+ ur_program_handle_t hProgram, // /< [in] handle of the program instance
325
+ const char *pKernelName, // /< [in] pointer to null-terminated string.
326
+ ur_kernel_handle_t
327
+ *phKernel // /< [out] pointer to handle of kernel object created.
328
+ ) {
329
+ auto pfnCreate = context.urDdiTable .Kernel .pfnCreate ;
330
+
331
+ if (nullptr == pfnCreate) {
332
+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
333
+ }
334
+
335
+ context.logger .debug (" ==== urKernelCreate" );
336
+
337
+ UR_CALL (pfnCreate (hProgram, pKernelName, phKernel));
338
+ UR_CALL (context.interceptor ->insertKernel (*phKernel));
339
+
340
+ return UR_RESULT_SUCCESS;
341
+ }
342
+
343
+ // /////////////////////////////////////////////////////////////////////////////
344
+ // / @brief Intercept function for urKernelRelease
345
+ __urdlllocal ur_result_t urKernelRelease (
346
+ ur_kernel_handle_t hKernel // /< [in] handle for the Kernel to release
347
+ ) {
348
+ auto pfnRelease = context.urDdiTable .Kernel .pfnRelease ;
349
+
350
+ if (nullptr == pfnRelease) {
351
+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
352
+ }
353
+
354
+ context.logger .debug (" ==== urKernelRelease" );
355
+ UR_CALL (pfnRelease (hKernel));
356
+
357
+ if (auto KernelInfo = context.interceptor ->getKernelInfo (hKernel)) {
358
+ uint32_t RefCount;
359
+ UR_CALL (context.urDdiTable .Kernel .pfnGetInfo (
360
+ hKernel, UR_KERNEL_INFO_REFERENCE_COUNT, sizeof (RefCount),
361
+ &RefCount, nullptr ));
362
+ if (RefCount == 1 ) {
363
+ UR_CALL (context.interceptor ->eraseKernel (hKernel));
364
+ }
365
+ }
366
+
367
+ return UR_RESULT_SUCCESS;
368
+ }
369
+
370
+ // /////////////////////////////////////////////////////////////////////////////
371
+ // / @brief Intercept function for urKernelSetArgLocal
372
+ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal (
373
+ ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
374
+ uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
375
+ size_t
376
+ argSize, // /< [in] size of the local buffer to be allocated by the runtime
377
+ const ur_kernel_arg_local_properties_t
378
+ *pProperties // /< [in][optional] pointer to local buffer properties.
379
+ ) {
380
+ auto pfnSetArgLocal = context.urDdiTable .Kernel .pfnSetArgLocal ;
381
+
382
+ if (nullptr == pfnSetArgLocal) {
383
+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
384
+ }
385
+
386
+ context.logger .debug (" ==== urKernelSetArgLocal (argIndex={}, argSize={})" ,
387
+ argIndex, argSize);
388
+
389
+ {
390
+ auto KI = context.interceptor ->getKernelInfo (hKernel);
391
+ std::scoped_lock<ur_shared_mutex> Guard (KI->Mutex );
392
+ // TODO: get local variable alignment
393
+ auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal (
394
+ argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
395
+ KI->LocalArgs [argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
396
+ argSize = argSizeWithRZ;
397
+ }
398
+
399
+ ur_result_t result =
400
+ pfnSetArgLocal (hKernel, argIndex, argSize, pProperties);
401
+
402
+ return result;
403
+ }
404
+
320
405
// /////////////////////////////////////////////////////////////////////////////
321
406
// / @brief Exported function for filling application's Context table
322
407
// / with current process' addresses
@@ -410,6 +495,38 @@ __urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
410
495
return result;
411
496
}
412
497
// /////////////////////////////////////////////////////////////////////////////
498
+ // / @brief Exported function for filling application's Kernel table
499
+ // / with current process' addresses
500
+ // /
501
+ // / @returns
502
+ // / - ::UR_RESULT_SUCCESS
503
+ // / - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
504
+ // / - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
505
+ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable (
506
+ ur_api_version_t version, // /< [in] API version requested
507
+ ur_kernel_dditable_t
508
+ *pDdiTable // /< [in,out] pointer to table of DDI function pointers
509
+ ) {
510
+ if (nullptr == pDdiTable) {
511
+ return UR_RESULT_ERROR_INVALID_NULL_POINTER;
512
+ }
513
+
514
+ if (UR_MAJOR_VERSION (ur_sanitizer_layer::context.version ) !=
515
+ UR_MAJOR_VERSION (version) ||
516
+ UR_MINOR_VERSION (ur_sanitizer_layer::context.version ) >
517
+ UR_MINOR_VERSION (version)) {
518
+ return UR_RESULT_ERROR_UNSUPPORTED_VERSION;
519
+ }
520
+
521
+ ur_result_t result = UR_RESULT_SUCCESS;
522
+
523
+ pDdiTable->pfnCreate = ur_sanitizer_layer::urKernelCreate;
524
+ pDdiTable->pfnRelease = ur_sanitizer_layer::urKernelRelease;
525
+ pDdiTable->pfnSetArgLocal = ur_sanitizer_layer::urKernelSetArgLocal;
526
+
527
+ return result;
528
+ }
529
+ // /////////////////////////////////////////////////////////////////////////////
413
530
// / @brief Exported function for filling application's Enqueue table
414
531
// / with current process' addresses
415
532
// /
@@ -509,6 +626,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable,
509
626
UR_API_VERSION_CURRENT, &dditable->Context );
510
627
}
511
628
629
+ if (UR_RESULT_SUCCESS == result) {
630
+ result = ur_sanitizer_layer::urGetKernelProcAddrTable (
631
+ UR_API_VERSION_CURRENT, &dditable->Kernel );
632
+ }
633
+
512
634
if (UR_RESULT_SUCCESS == result) {
513
635
result = ur_sanitizer_layer::urGetProgramProcAddrTable (
514
636
UR_API_VERSION_CURRENT, &dditable->Program );
0 commit comments