@@ -24,13 +24,37 @@ from templates import helper as th
2424namespace ur_loader
2525{
2626 % for obj in th.get_adapter_functions(specs):
27+ <%
28+ func_name = th.make_func_name(n, tags, obj)
29+ if func_name.startswith(x):
30+ func_basename = func_name[len (x):]
31+ else :
32+ func_basename = func_name
33+ %>
34+ % if func_basename == " EventSetCallback" :
35+ namespace {
36+ struct event_callback_wrapper_data_t {
37+ ${ x} _event_callback_t fn;
38+ ${ x} _event_handle_t event;
39+ void *userData;
40+ };
41+
42+ void event_callback_wrapper([[maybe_unused]] ${ x} _event_handle_t hEvent,
43+ ${ x} _execution_info_t execStatus, void *pUserData) {
44+ auto *wrapper =
45+ reinterpret_cast<event _callback_wrapper_data_t * >(pUserData);
46+ (wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
47+ delete wrapper;
48+ }
49+ }
50+
51+ %endif
2752 ///////////////////////////////////////////////////////////////////////////////
28- /// @brief Intercept function for ${ th.make_func_name(n, tags, obj) }
53+ /// @brief Intercept function for ${ func_name }
2954 % if ' condition' in obj:
3055 #if ${ th.subt(n, tags, obj[' condition' ])}
3156 %endif
32- __${ x} dlllocal ${ x} _result_t ${ X} _APICALL
33- ${ th.make_func_name(n, tags, obj)} (
57+ __${ x} dlllocal ${ x} _result_t ${ X} _APICALL ${ func_name} (
3458 % for line in th.make_param_lines(n, tags, obj):
3559 ${ line}
3660 %endfor
@@ -41,7 +65,16 @@ namespace ur_loader
4165 %> ${ th.get_initial_null_set(obj)}
4266
4367 [[maybe_unused]] auto context = getContext();
44- % if re.match(r " \w + AdapterGet$ " , th.make_func_name(n, tags, obj)):
68+ % if func_basename == " EventSetCallback" :
69+
70+ // Replace the callback with a wrapper function that gives the callback the loader event rather than a
71+ // backend-specific event
72+ auto wrapper_data =
73+ new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
74+ pUserData = wrapper_data;
75+ pfnNotify = event_callback_wrapper;
76+ %endif
77+ % if func_basename == " AdapterGet" :
4578
4679 size_t adapterIndex = 0;
4780 if( nullptr != ${ obj[' params' ][1 ][' name' ]} && ${ obj[' params' ][0 ][' name' ]} !=0)
@@ -74,7 +107,7 @@ namespace ur_loader
74107 *${ obj[' params' ][2 ][' name' ]} = static_cast<uint32 _t >(context->platforms.size());
75108 }
76109
77- % elif re.match( r " \w + PlatformGet$ " , th.make_func_name(n, tags, obj)) :
110+ % elif func_basename == " PlatformGet" :
78111 uint32_t total_platform_handle_count = 0;
79112
80113 for( uint32_t adapter_index = 0; adapter_index < ${ obj[' params' ][1 ][' name' ]} ; adapter_index++)
@@ -263,7 +296,7 @@ namespace ur_loader
263296 % for i, item in enumerate (epilogue):
264297 % if 0 == i and not item[' release' ] and not item[' retain' ] and not th.always_wrap_outputs(obj):
265298 ## TODO: Remove once we have a concrete way for submitting warnings in place.
266- % if re.match(r " urEnqueue \w + " , th.make_func_name(n, tags, obj) ):
299+ % if re.match(r " Enqueue \w + " , func_basename ):
267300 // In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
268301 if( ${ X} _RESULT_SUCCESS != result && ${ X} _RESULT_ERROR_ADAPTER_SPECIFIC != result )
269302 return result;
0 commit comments