Skip to content

Commit 737f344

Browse files
Ewan CrawfordRossBrunton
andcommitted
Add handle translation for entry-point struct list
To workaround #2671 where the spec generator cannot translate handles inside a list of structs, add a special case to the Mako file so that we can handle the new entry-point. This is a temporary measure until the work refactoring handle translation is complete. Co-authored-by: Ross Brunton <[email protected]>
1 parent f55f193 commit 737f344

File tree

2 files changed

+56
-25
lines changed

2 files changed

+56
-25
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,16 @@ namespace ur_loader
212212
<% handle_structs = th.get_object_handle_structs_to_convert(n, tags, obj, meta) %>
213213
%if handle_structs:
214214
// Deal with any struct parameters that have handle members we need to convert.
215+
%if func_basename == "CommandBufferUpdateKernelLaunchExp":
216+
## CommandBufferUpdateKernelLaunchExp entry-point takes a list of structs with
217+
## handle members, as well as members defining a nested list of structs
218+
## containing handles. This useage is not supported yet, so special case as
219+
## a temporary measure.
220+
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> pUpdateKernelLaunchVector = {};
221+
std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
222+
ppUpdateKernelLaunchpNewMemObjArgList(numKernelUpdates);
223+
for (size_t Offset = 0; Offset < numKernelUpdates; Offset ++) {
224+
%endif
215225
%for struct in handle_structs:
216226
%if struct['optional']:
217227
${struct['type']} ${struct['name']}Local = {};
@@ -239,7 +249,13 @@ namespace ur_loader
239249
range_end = member['range_end']
240250
if not re.match(r"[0-9]+$", range_end):
241251
range_end = struct['name'] + "->" + member['parent'] + range_end %>
252+
253+
%if func_basename == "CommandBufferUpdateKernelLaunchExp":
254+
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>&
255+
pUpdateKernelLaunchpNewMemObjArgList = ppUpdateKernelLaunchpNewMemObjArgList[Offset];
256+
%else:
242257
std::vector<${member['type']}> ${range_vector_name};
258+
%endif
243259
for(uint32_t i = ${range_start}; i < ${range_end}; i++) {
244260
${member['type']} NewRangeStruct = ${struct['name']}Local.${member['parent']}${member['name']}[i];
245261
%for handle_member in member['handle_members']:
@@ -277,6 +293,12 @@ namespace ur_loader
277293
%endfor
278294
%endfor
279295
296+
%if func_basename == "CommandBufferUpdateKernelLaunchExp":
297+
pUpdateKernelLaunchVector.push_back(pUpdateKernelLaunchLocal);
298+
pUpdateKernelLaunch++;
299+
}
300+
pUpdateKernelLaunch = pUpdateKernelLaunchVector.data();
301+
%else:
280302
// Now that we've converted all the members update the param pointers
281303
%for struct in handle_structs:
282304
%if struct['optional']:
@@ -285,6 +307,7 @@ namespace ur_loader
285307
${struct['name']} = &${struct['name']}Local;
286308
%endfor
287309
%endif
310+
%endif
288311
289312
// forward to device-platform
290313
%if add_local:

source/loader/ur_ldrddi.cpp

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8411,35 +8411,43 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
84118411

84128412
// Deal with any struct parameters that have handle members we need to
84138413
// convert.
8414-
auto pUpdateKernelLaunchLocal = *pUpdateKernelLaunch;
8415-
8416-
pUpdateKernelLaunchLocal.hCommand =
8417-
reinterpret_cast<ur_exp_command_buffer_command_object_t *>(
8418-
pUpdateKernelLaunchLocal.hCommand)
8419-
->handle;
8420-
if (pUpdateKernelLaunchLocal.hNewKernel)
8421-
pUpdateKernelLaunchLocal.hNewKernel =
8422-
reinterpret_cast<ur_kernel_object_t *>(
8423-
pUpdateKernelLaunchLocal.hNewKernel)
8414+
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t>
8415+
pUpdateKernelLaunchVector = {};
8416+
std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
8417+
ppUpdateKernelLaunchpNewMemObjArgList(numKernelUpdates);
8418+
for (size_t Offset = 0; Offset < numKernelUpdates; Offset++) {
8419+
auto pUpdateKernelLaunchLocal = *pUpdateKernelLaunch;
8420+
8421+
pUpdateKernelLaunchLocal.hCommand =
8422+
reinterpret_cast<ur_exp_command_buffer_command_object_t *>(
8423+
pUpdateKernelLaunchLocal.hCommand)
84248424
->handle;
8425-
8426-
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>
8427-
pUpdateKernelLaunchpNewMemObjArgList;
8428-
for (uint32_t i = 0; i < pUpdateKernelLaunch->numNewMemObjArgs; i++) {
8429-
ur_exp_command_buffer_update_memobj_arg_desc_t NewRangeStruct =
8430-
pUpdateKernelLaunchLocal.pNewMemObjArgList[i];
8431-
if (NewRangeStruct.hNewMemObjArg)
8432-
NewRangeStruct.hNewMemObjArg =
8433-
reinterpret_cast<ur_mem_object_t *>(NewRangeStruct.hNewMemObjArg)
8425+
if (pUpdateKernelLaunchLocal.hNewKernel)
8426+
pUpdateKernelLaunchLocal.hNewKernel =
8427+
reinterpret_cast<ur_kernel_object_t *>(
8428+
pUpdateKernelLaunchLocal.hNewKernel)
84348429
->handle;
84358430

8436-
pUpdateKernelLaunchpNewMemObjArgList.push_back(NewRangeStruct);
8437-
}
8438-
pUpdateKernelLaunchLocal.pNewMemObjArgList =
8439-
pUpdateKernelLaunchpNewMemObjArgList.data();
8431+
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>
8432+
&pUpdateKernelLaunchpNewMemObjArgList =
8433+
ppUpdateKernelLaunchpNewMemObjArgList[Offset];
8434+
for (uint32_t i = 0; i < pUpdateKernelLaunch->numNewMemObjArgs; i++) {
8435+
ur_exp_command_buffer_update_memobj_arg_desc_t NewRangeStruct =
8436+
pUpdateKernelLaunchLocal.pNewMemObjArgList[i];
8437+
if (NewRangeStruct.hNewMemObjArg)
8438+
NewRangeStruct.hNewMemObjArg =
8439+
reinterpret_cast<ur_mem_object_t *>(NewRangeStruct.hNewMemObjArg)
8440+
->handle;
8441+
8442+
pUpdateKernelLaunchpNewMemObjArgList.push_back(NewRangeStruct);
8443+
}
8444+
pUpdateKernelLaunchLocal.pNewMemObjArgList =
8445+
pUpdateKernelLaunchpNewMemObjArgList.data();
84408446

8441-
// Now that we've converted all the members update the param pointers
8442-
pUpdateKernelLaunch = &pUpdateKernelLaunchLocal;
8447+
pUpdateKernelLaunchVector.push_back(pUpdateKernelLaunchLocal);
8448+
pUpdateKernelLaunch++;
8449+
}
8450+
pUpdateKernelLaunch = pUpdateKernelLaunchVector.data();
84438451

84448452
// forward to device-platform
84458453
result = pfnUpdateKernelLaunchExp(hCommandBuffer, numKernelUpdates,

0 commit comments

Comments
 (0)