Skip to content

Commit 5bd6d94

Browse files
ldrummkbenzie
authored andcommitted
Introduce a utility to walk polymorphic linked lists
Add `find_stype_node` for walking polymorphic linked lists looking for a particular type. The implementation here works with an auto-generated compile time map, that links a given type to the structure enumeration tag. The implementation simply walks the list looking for the .stype implied by the template parameter type and casts it to the expected type, then returning it. It's one of those unfortunate cases where you can write pretty nasty implementation code that makes the user code much much nicer. In this case the user doesn't need to worry about the `.types` at all, and the const-void casts can be eliminated from user code.
1 parent 3eda0d5 commit 5bd6d94

File tree

5 files changed

+182
-14
lines changed

5 files changed

+182
-14
lines changed

scripts/generate_code.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,25 @@ def generate_layers(path, section, namespace, tags, version, specs, meta):
411411
generates common utilities for unified_runtime
412412
"""
413413
def generate_common(path, section, namespace, tags, version, specs, meta):
414+
template = "stype_map_helpers.hpp.mako"
415+
fin = os.path.join("templates", template)
416+
417+
filename = "stype_map_helpers.def"
414418
layer_dstpath = os.path.join(path, "common")
415419
os.makedirs(layer_dstpath, exist_ok=True)
420+
fout = os.path.join(layer_dstpath, filename)
421+
422+
print("Generating %s..." % fout)
423+
424+
loc = util.makoWrite(
425+
fin, fout,
426+
ver=version,
427+
namespace=namespace,
428+
tags=tags,
429+
specs=specs,
430+
meta=meta)
431+
print("COMMON Generated %s lines of code.\n" % loc)
416432

417-
loc = 0
418-
print("COMMON Generated %s lines of code.\n"%loc)
419433

420434
"""
421435
Entry-point:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
<%!
2+
import re
3+
from templates import helper as th
4+
%><%
5+
n=namespace
6+
N=n.upper()
7+
x=tags['$x']
8+
X=x.upper()
9+
%>
10+
// This file is autogenerated from the template at ${self.template.filename}
11+
12+
%for obj in th.extract_objs(specs, r"enum"):
13+
%if obj["name"] == '$x_structure_type_t':
14+
%for etor in obj['etors']:
15+
%if 'UINT32' not in etor['name']:
16+
template <>
17+
struct stype_map<${x}_${etor['desc'][3:]}> : stype_map_impl<${X}_${etor['name'][3:]}> {};
18+
%endif
19+
%endfor
20+
%endif
21+
%endfor
22+

source/adapters/hip/usm.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -327,24 +327,15 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
327327
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
328328
ur_usm_pool_desc_t *PoolDesc)
329329
: Context(Context) {
330-
const void *pNext = PoolDesc->pNext;
331-
while (pNext != nullptr) {
332-
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
333-
switch (BaseDesc->stype) {
334-
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
335-
const ur_usm_pool_limits_desc_t *Limits =
336-
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
330+
if (PoolDesc) {
331+
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {
337332
for (auto &config : DisjointPoolConfigs.Configs) {
338333
config.MaxPoolableSize = Limits->maxPoolableSize;
339334
config.SlabMinSize = Limits->minDriverAllocSize;
340335
}
341-
break;
342-
}
343-
default: {
336+
} else {
344337
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
345338
}
346-
}
347-
pNext = BaseDesc->pNext;
348339
}
349340

350341
auto MemProvider =

source/common/stype_map_helpers.def

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
2+
// This file is autogenerated from the template at templates/stype_map_helpers.hpp.mako
3+
4+
template <>
5+
struct stype_map<ur_context_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_CONTEXT_PROPERTIES> {};
6+
template <>
7+
struct stype_map<ur_image_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_IMAGE_DESC> {};
8+
template <>
9+
struct stype_map<ur_buffer_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_PROPERTIES> {};
10+
template <>
11+
struct stype_map<ur_buffer_region_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_REGION> {};
12+
template <>
13+
struct stype_map<ur_buffer_channel_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_CHANNEL_PROPERTIES> {};
14+
template <>
15+
struct stype_map<ur_buffer_alloc_location_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_ALLOC_LOCATION_PROPERTIES> {};
16+
template <>
17+
struct stype_map<ur_program_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PROGRAM_PROPERTIES> {};
18+
template <>
19+
struct stype_map<ur_usm_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_DESC> {};
20+
template <>
21+
struct stype_map<ur_usm_host_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_HOST_DESC> {};
22+
template <>
23+
struct stype_map<ur_usm_device_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_DEVICE_DESC> {};
24+
template <>
25+
struct stype_map<ur_usm_pool_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_POOL_DESC> {};
26+
template <>
27+
struct stype_map<ur_usm_pool_limits_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC> {};
28+
template <>
29+
struct stype_map<ur_device_binary_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_BINARY> {};
30+
template <>
31+
struct stype_map<ur_sampler_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_SAMPLER_DESC> {};
32+
template <>
33+
struct stype_map<ur_queue_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_PROPERTIES> {};
34+
template <>
35+
struct stype_map<ur_queue_index_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_INDEX_PROPERTIES> {};
36+
template <>
37+
struct stype_map<ur_context_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_CONTEXT_NATIVE_PROPERTIES> {};
38+
template <>
39+
struct stype_map<ur_kernel_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES> {};
40+
template <>
41+
struct stype_map<ur_queue_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_NATIVE_PROPERTIES> {};
42+
template <>
43+
struct stype_map<ur_mem_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_MEM_NATIVE_PROPERTIES> {};
44+
template <>
45+
struct stype_map<ur_event_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES> {};
46+
template <>
47+
struct stype_map<ur_platform_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PLATFORM_NATIVE_PROPERTIES> {};
48+
template <>
49+
struct stype_map<ur_device_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES> {};
50+
template <>
51+
struct stype_map<ur_program_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES> {};
52+
template <>
53+
struct stype_map<ur_sampler_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_SAMPLER_NATIVE_PROPERTIES> {};
54+
template <>
55+
struct stype_map<ur_queue_native_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_NATIVE_DESC> {};
56+
template <>
57+
struct stype_map<ur_device_partition_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES> {};
58+
template <>
59+
struct stype_map<ur_kernel_arg_mem_obj_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES> {};
60+
template <>
61+
struct stype_map<ur_physical_mem_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES> {};
62+
template <>
63+
struct stype_map<ur_kernel_arg_pointer_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_POINTER_PROPERTIES> {};
64+
template <>
65+
struct stype_map<ur_kernel_arg_sampler_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_SAMPLER_PROPERTIES> {};
66+
template <>
67+
struct stype_map<ur_kernel_exec_info_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_EXEC_INFO_PROPERTIES> {};
68+
template <>
69+
struct stype_map<ur_kernel_arg_value_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_VALUE_PROPERTIES> {};
70+
template <>
71+
struct stype_map<ur_kernel_arg_local_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_LOCAL_PROPERTIES> {};
72+
template <>
73+
struct stype_map<ur_usm_alloc_location_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC> {};
74+
template <>
75+
struct stype_map<ur_exp_command_buffer_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC> {};
76+
template <>
77+
struct stype_map<ur_exp_command_buffer_update_kernel_launch_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC> {};
78+
template <>
79+
struct stype_map<ur_exp_command_buffer_update_memobj_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC> {};
80+
template <>
81+
struct stype_map<ur_exp_command_buffer_update_pointer_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC> {};
82+
template <>
83+
struct stype_map<ur_exp_command_buffer_update_value_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC> {};
84+
template <>
85+
struct stype_map<ur_exp_sampler_mip_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES> {};
86+
template <>
87+
struct stype_map<ur_exp_interop_mem_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_INTEROP_MEM_DESC> {};
88+
template <>
89+
struct stype_map<ur_exp_interop_semaphore_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_INTEROP_SEMAPHORE_DESC> {};
90+
template <>
91+
struct stype_map<ur_exp_file_descriptor_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_FILE_DESCRIPTOR> {};
92+
template <>
93+
struct stype_map<ur_exp_win32_handle_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_WIN32_HANDLE> {};
94+
template <>
95+
struct stype_map<ur_exp_sampler_addr_modes_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES> {};
96+
template <>
97+
struct stype_map<ur_exp_sampler_cubemap_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_CUBEMAP_PROPERTIES> {};
98+

source/common/ur_util.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,49 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) {
281281

282282
template <class> inline constexpr bool ur_always_false_t = false;
283283

284+
namespace {
285+
// Compile-time map, mapping a UR list node type, to the enum tag type
286+
// These are helpers for the `find_stype_node` helper below
287+
template <ur_structure_type_t val> struct stype_map_impl {
288+
static constexpr ur_structure_type_t value = val;
289+
};
290+
291+
template <typename T> struct stype_map {};
292+
// contains definitions of the map specializations e.g.
293+
// template <> struct stype_map<ur_usm_device_desc_t> :
294+
// stype_map_impl<UR_STRUCTURE_TYPE_USM_DEVICE_DESC> {};
295+
#include "stype_map_helpers.def"
296+
297+
template <typename T> constexpr int as_stype() { return stype_map<T>::value; };
298+
299+
/// Walk a generic UR linked list looking for a node of the given type. If it's
300+
/// found, its address is returned, othewise `nullptr`. e.g. to find out whether
301+
/// a `ur_usm_host_desc_t` exists in the given polymorphic list, `mylist`:
302+
///
303+
/// ```cpp
304+
/// auto *node = find_stype_node<ur_usm_host_desc_t>(mylist);
305+
/// if (!node)
306+
/// printf("node of expected type not found!\n");
307+
/// ```
308+
template <typename T, typename P>
309+
typename std::conditional_t<std::is_const_v<std::remove_pointer_t<P>>,
310+
const T *, T *>
311+
find_stype_node(P list_head) noexcept {
312+
auto *list = reinterpret_cast<const T *>(list_head);
313+
for (const auto *next = reinterpret_cast<const T *>(list); next;
314+
next = reinterpret_cast<const T *>(next->pNext)) {
315+
if (next->stype == as_stype<T>()) {
316+
if constexpr (!std::is_const_v<P>) {
317+
return const_cast<T *>(next);
318+
} else {
319+
return next;
320+
}
321+
}
322+
}
323+
return nullptr;
324+
}
325+
} // namespace
326+
284327
namespace ur {
285328
[[noreturn]] inline void unreachable() {
286329
#ifdef _MSC_VER

0 commit comments

Comments
 (0)