@@ -217,6 +217,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
217217 USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
218218 pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
219219 workDim);
220+ UR_CALL (LaunchInfo.initialize ());
220221
221222 UR_CALL (context.interceptor ->preLaunchKernel (hKernel, hQueue, LaunchInfo));
222223
@@ -317,6 +318,90 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
317318 return result;
318319}
319320
321+ // /////////////////////////////////////////////////////////////////////////////
322+ // / @brief Intercept function for urKernelCreate
323+ __urdlllocal ur_result_t UR_APICALL urKernelCreate (
324+ ur_program_handle_t hProgram, // /< [in] handle of the program instance
325+ const char *pKernelName, // /< [in] pointer to null-terminated string.
326+ ur_kernel_handle_t
327+ *phKernel // /< [out] pointer to handle of kernel object created.
328+ ) {
329+ auto pfnCreate = context.urDdiTable .Kernel .pfnCreate ;
330+
331+ if (nullptr == pfnCreate) {
332+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
333+ }
334+
335+ context.logger .debug (" ==== urKernelCreate" );
336+
337+ UR_CALL (pfnCreate (hProgram, pKernelName, phKernel));
338+ UR_CALL (context.interceptor ->insertKernel (*phKernel));
339+
340+ return UR_RESULT_SUCCESS;
341+ }
342+
343+ // /////////////////////////////////////////////////////////////////////////////
344+ // / @brief Intercept function for urKernelRelease
345+ __urdlllocal ur_result_t urKernelRelease (
346+ ur_kernel_handle_t hKernel // /< [in] handle for the Kernel to release
347+ ) {
348+ auto pfnRelease = context.urDdiTable .Kernel .pfnRelease ;
349+
350+ if (nullptr == pfnRelease) {
351+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
352+ }
353+
354+ context.logger .debug (" ==== urKernelRelease" );
355+ UR_CALL (pfnRelease (hKernel));
356+
357+ if (auto KernelInfo = context.interceptor ->getKernelInfo (hKernel)) {
358+ uint32_t RefCount;
359+ UR_CALL (context.urDdiTable .Kernel .pfnGetInfo (
360+ hKernel, UR_KERNEL_INFO_REFERENCE_COUNT, sizeof (RefCount),
361+ &RefCount, nullptr ));
362+ if (RefCount == 1 ) {
363+ UR_CALL (context.interceptor ->eraseKernel (hKernel));
364+ }
365+ }
366+
367+ return UR_RESULT_SUCCESS;
368+ }
369+
370+ // /////////////////////////////////////////////////////////////////////////////
371+ // / @brief Intercept function for urKernelSetArgLocal
372+ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal (
373+ ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
374+ uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
375+ size_t
376+ argSize, // /< [in] size of the local buffer to be allocated by the runtime
377+ const ur_kernel_arg_local_properties_t
378+ *pProperties // /< [in][optional] pointer to local buffer properties.
379+ ) {
380+ auto pfnSetArgLocal = context.urDdiTable .Kernel .pfnSetArgLocal ;
381+
382+ if (nullptr == pfnSetArgLocal) {
383+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
384+ }
385+
386+ context.logger .debug (" ==== urKernelSetArgLocal (argIndex={}, argSize={})" ,
387+ argIndex, argSize);
388+
389+ {
390+ auto KI = context.interceptor ->getKernelInfo (hKernel);
391+ std::scoped_lock<ur_shared_mutex> Guard (KI->Mutex );
392+ // TODO: get local variable alignment
393+ auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal (
394+ argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
395+ KI->LocalArgs [argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
396+ argSize = argSizeWithRZ;
397+ }
398+
399+ ur_result_t result =
400+ pfnSetArgLocal (hKernel, argIndex, argSize, pProperties);
401+
402+ return result;
403+ }
404+
320405// /////////////////////////////////////////////////////////////////////////////
321406// / @brief Exported function for filling application's Context table
322407// / with current process' addresses
@@ -410,6 +495,38 @@ __urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
410495 return result;
411496}
412497// /////////////////////////////////////////////////////////////////////////////
498+ // / @brief Exported function for filling application's Kernel table
499+ // / with current process' addresses
500+ // /
501+ // / @returns
502+ // / - ::UR_RESULT_SUCCESS
503+ // / - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
504+ // / - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
505+ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable (
506+ ur_api_version_t version, // /< [in] API version requested
507+ ur_kernel_dditable_t
508+ *pDdiTable // /< [in,out] pointer to table of DDI function pointers
509+ ) {
510+ if (nullptr == pDdiTable) {
511+ return UR_RESULT_ERROR_INVALID_NULL_POINTER;
512+ }
513+
514+ if (UR_MAJOR_VERSION (ur_sanitizer_layer::context.version ) !=
515+ UR_MAJOR_VERSION (version) ||
516+ UR_MINOR_VERSION (ur_sanitizer_layer::context.version ) >
517+ UR_MINOR_VERSION (version)) {
518+ return UR_RESULT_ERROR_UNSUPPORTED_VERSION;
519+ }
520+
521+ ur_result_t result = UR_RESULT_SUCCESS;
522+
523+ pDdiTable->pfnCreate = ur_sanitizer_layer::urKernelCreate;
524+ pDdiTable->pfnRelease = ur_sanitizer_layer::urKernelRelease;
525+ pDdiTable->pfnSetArgLocal = ur_sanitizer_layer::urKernelSetArgLocal;
526+
527+ return result;
528+ }
529+ // /////////////////////////////////////////////////////////////////////////////
413530// / @brief Exported function for filling application's Enqueue table
414531// / with current process' addresses
415532// /
@@ -509,6 +626,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable,
509626 UR_API_VERSION_CURRENT, &dditable->Context );
510627 }
511628
629+ if (UR_RESULT_SUCCESS == result) {
630+ result = ur_sanitizer_layer::urGetKernelProcAddrTable (
631+ UR_API_VERSION_CURRENT, &dditable->Kernel );
632+ }
633+
512634 if (UR_RESULT_SUCCESS == result) {
513635 result = ur_sanitizer_layer::urGetProgramProcAddrTable (
514636 UR_API_VERSION_CURRENT, &dditable->Program );
0 commit comments