@@ -24,13 +24,37 @@ from templates import helper as th
2424namespace ur_loader
2525{
2626 % for obj in th.get_adapter_functions(specs):
27+ <%
28+ func_name = th.make_func_name(n, tags, obj)
29+ if func_name.startswith(x):
30+ func_basename = func_name[len (x):]
31+ else :
32+ func_basename = func_name
33+ %>
34+ % if func_basename == " EventSetCallback" :
35+ namespace {
36+ struct event_callback_wrapper_data_t {
37+ ${ x} _event_callback_t fn;
38+ ${ x} _event_handle_t event;
39+ void *userData;
40+ };
41+
42+ void event_callback_wrapper([[maybe_unused]] ${ x} _event_handle_t hEvent,
43+ ${ x} _execution_info_t execStatus, void *pUserData) {
44+ auto *wrapper =
45+ reinterpret_cast<event _callback_wrapper_data_t * >(pUserData);
46+ (wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
47+ delete wrapper;
48+ }
49+ }
50+
51+ %endif
2752 ///////////////////////////////////////////////////////////////////////////////
28- /// @brief Intercept function for ${ th.make_func_name(n, tags, obj) }
53+ /// @brief Intercept function for ${ func_name }
2954 % if ' condition' in obj:
3055 #if ${ th.subt(n, tags, obj[' condition' ])}
3156 %endif
32- __${ x} dlllocal ${ x} _result_t ${ X} _APICALL
33- ${ th.make_func_name(n, tags, obj)} (
57+ __${ x} dlllocal ${ x} _result_t ${ X} _APICALL ${ func_name} (
3458 % for line in th.make_param_lines(n, tags, obj):
3559 ${ line}
3660 %endfor
@@ -41,7 +65,7 @@ namespace ur_loader
4165 %> ${ th.get_initial_null_set(obj)}
4266
4367 [[maybe_unused]] auto context = getContext();
44- % if re.match( r " \w + AdapterGet$ " , th.make_func_name(n, tags, obj)) :
68+ % if func_basename == " AdapterGet" :
4569
4670 size_t adapterIndex = 0;
4771 if( nullptr != ${ obj[' params' ][1 ][' name' ]} && ${ obj[' params' ][0 ][' name' ]} !=0)
@@ -74,7 +98,7 @@ namespace ur_loader
7498 *${ obj[' params' ][2 ][' name' ]} = static_cast<uint32 _t >(context->platforms.size());
7599 }
76100
77- % elif re.match( r " \w + PlatformGet$ " , th.make_func_name(n, tags, obj)) :
101+ % elif func_basename == " PlatformGet" :
78102 uint32_t total_platform_handle_count = 0;
79103
80104 for( uint32_t adapter_index = 0; adapter_index < ${ obj[' params' ][1 ][' name' ]} ; adapter_index++)
@@ -132,6 +156,16 @@ namespace ur_loader
132156 <%break %>
133157 %endif
134158 %endfor
159+ % if func_basename == " EventSetCallback" :
160+
161+ // Replace the callback with a wrapper function that gives the callback the loader event rather than a
162+ // backend-specific event
163+ auto *wrapper_data =
164+ new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
165+ pUserData = wrapper_data;
166+ pfnNotify = event_callback_wrapper;
167+
168+ %endif
135169 % for i, item in enumerate (th.get_loader_prologue(n, tags, obj, meta)):
136170 % if ' range' in item:
137171 <%
@@ -263,7 +297,7 @@ namespace ur_loader
263297 % for i, item in enumerate (epilogue):
264298 % if 0 == i and not item[' release' ] and not item[' retain' ] and not th.always_wrap_outputs(obj):
265299 ## TODO: Remove once we have a concrete way for submitting warnings in place.
266- % if re.match(r " urEnqueue \w + " , th.make_func_name(n, tags, obj) ):
300+ % if re.match(r " Enqueue \w + " , func_basename ):
267301 // In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
268302 if( ${ X} _RESULT_SUCCESS != result && ${ X} _RESULT_ERROR_ADAPTER_SPECIFIC != result )
269303 return result;
0 commit comments