@@ -411,6 +411,29 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
411411 return result;
412412}
413413
414+ // /////////////////////////////////////////////////////////////////////////////
415+ // / @brief Intercept function for urContextRetain
416+ __urdlllocal ur_result_t UR_APICALL urContextRetain (
417+ ur_context_handle_t
418+ hContext // /< [in] handle of the context to get a reference of.
419+ ) {
420+ auto pfnRetain = getContext ()->urDdiTable .Context .pfnRetain ;
421+
422+ if (nullptr == pfnRetain) {
423+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
424+ }
425+
426+ getContext ()->logger .debug (" ==== urContextRetain" );
427+
428+ UR_CALL (pfnRetain (hContext));
429+
430+ auto ContextInfo = getContext ()->interceptor ->getContextInfo (hContext);
431+ UR_ASSERT (ContextInfo != nullptr , UR_RESULT_ERROR_INVALID_VALUE);
432+ ContextInfo->RefCount ++;
433+
434+ return UR_RESULT_SUCCESS;
435+ }
436+
414437// /////////////////////////////////////////////////////////////////////////////
415438// / @brief Intercept function for urContextRelease
416439__urdlllocal ur_result_t UR_APICALL urContextRelease (
@@ -424,10 +447,15 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
424447
425448 getContext ()->logger .debug (" ==== urContextRelease" );
426449
427- UR_CALL (getContext ()->interceptor ->eraseContext (hContext));
428- ur_result_t result = pfnRelease (hContext);
450+ UR_CALL (pfnRelease (hContext));
429451
430- return result;
452+ auto ContextInfo = getContext ()->interceptor ->getContextInfo (hContext);
453+ UR_ASSERT (ContextInfo != nullptr , UR_RESULT_ERROR_INVALID_VALUE);
454+ if (--ContextInfo->RefCount == 0 ) {
455+ UR_CALL (getContext ()->interceptor ->eraseContext (hContext));
456+ }
457+
458+ return UR_RESULT_SUCCESS;
431459}
432460
433461// /////////////////////////////////////////////////////////////////////////////
@@ -1207,9 +1235,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
12071235
12081236 UR_CALL (pfnRetain (hKernel));
12091237
1210- if ( auto KernelInfo = getContext ()->interceptor ->getKernelInfo (hKernel)) {
1211- KernelInfo-> RefCount ++ ;
1212- }
1238+ auto KernelInfo = getContext ()->interceptor ->getKernelInfo (hKernel);
1239+ UR_ASSERT (KernelInfo != nullptr , UR_RESULT_ERROR_INVALID_VALUE) ;
1240+ KernelInfo-> RefCount ++;
12131241
12141242 return UR_RESULT_SUCCESS;
12151243}
@@ -1228,10 +1256,9 @@ __urdlllocal ur_result_t urKernelRelease(
12281256 getContext ()->logger .debug (" ==== urKernelRelease" );
12291257 UR_CALL (pfnRelease (hKernel));
12301258
1231- if (auto KernelInfo = getContext ()->interceptor ->getKernelInfo (hKernel)) {
1232- if (--KernelInfo->RefCount != 0 ) {
1233- return UR_RESULT_SUCCESS;
1234- }
1259+ auto KernelInfo = getContext ()->interceptor ->getKernelInfo (hKernel);
1260+ UR_ASSERT (KernelInfo != nullptr , UR_RESULT_ERROR_INVALID_VALUE);
1261+ if (--KernelInfo->RefCount == 0 ) {
12351262 UR_CALL (getContext ()->interceptor ->eraseKernel (hKernel));
12361263 }
12371264
@@ -1426,6 +1453,7 @@ __urdlllocal ur_result_t UR_APICALL urGetContextProcAddrTable(
14261453 ur_result_t result = UR_RESULT_SUCCESS;
14271454
14281455 pDdiTable->pfnCreate = ur_sanitizer_layer::urContextCreate;
1456+ pDdiTable->pfnRetain = ur_sanitizer_layer::urContextRetain;
14291457 pDdiTable->pfnRelease = ur_sanitizer_layer::urContextRelease;
14301458
14311459 pDdiTable->pfnCreateWithNativeHandle =
0 commit comments