Skip to content

Commit 4a7bf7b

Browse files
authored
Merge pull request #1421 from ldrumm/find-stype
Introduce a utility to walk polymorphic linked lists
2 parents a6b80d8 + 5bd6d94 commit 4a7bf7b

File tree

5 files changed

+182
-14
lines changed

5 files changed

+182
-14
lines changed

scripts/generate_code.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,25 @@ def generate_layers(path, section, namespace, tags, version, specs, meta):
411411
generates common utilities for unified_runtime
412412
"""
413413
def generate_common(path, section, namespace, tags, version, specs, meta):
414+
template = "stype_map_helpers.hpp.mako"
415+
fin = os.path.join("templates", template)
416+
417+
filename = "stype_map_helpers.def"
414418
layer_dstpath = os.path.join(path, "common")
415419
os.makedirs(layer_dstpath, exist_ok=True)
420+
fout = os.path.join(layer_dstpath, filename)
421+
422+
print("Generating %s..." % fout)
423+
424+
loc = util.makoWrite(
425+
fin, fout,
426+
ver=version,
427+
namespace=namespace,
428+
tags=tags,
429+
specs=specs,
430+
meta=meta)
431+
print("COMMON Generated %s lines of code.\n" % loc)
416432

417-
loc = 0
418-
print("COMMON Generated %s lines of code.\n"%loc)
419433

420434
"""
421435
Entry-point:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
<%!
2+
import re
3+
from templates import helper as th
4+
%><%
5+
n=namespace
6+
N=n.upper()
7+
x=tags['$x']
8+
X=x.upper()
9+
%>
10+
// This file is autogenerated from the template at ${self.template.filename}
11+
12+
%for obj in th.extract_objs(specs, r"enum"):
13+
%if obj["name"] == '$x_structure_type_t':
14+
%for etor in obj['etors']:
15+
%if 'UINT32' not in etor['name']:
16+
template <>
17+
struct stype_map<${x}_${etor['desc'][3:]}> : stype_map_impl<${X}_${etor['name'][3:]}> {};
18+
%endif
19+
%endfor
20+
%endif
21+
%endfor
22+

source/adapters/hip/usm.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -327,24 +327,15 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
327327
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
328328
ur_usm_pool_desc_t *PoolDesc)
329329
: Context(Context) {
330-
const void *pNext = PoolDesc->pNext;
331-
while (pNext != nullptr) {
332-
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
333-
switch (BaseDesc->stype) {
334-
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
335-
const ur_usm_pool_limits_desc_t *Limits =
336-
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
330+
if (PoolDesc) {
331+
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {
337332
for (auto &config : DisjointPoolConfigs.Configs) {
338333
config.MaxPoolableSize = Limits->maxPoolableSize;
339334
config.SlabMinSize = Limits->minDriverAllocSize;
340335
}
341-
break;
342-
}
343-
default: {
336+
} else {
344337
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
345338
}
346-
}
347-
pNext = BaseDesc->pNext;
348339
}
349340

350341
auto MemProvider =

source/common/stype_map_helpers.def

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
2+
// This file is autogenerated from the template at templates/stype_map_helpers.hpp.mako
3+
4+
template <>
5+
struct stype_map<ur_context_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_CONTEXT_PROPERTIES> {};
6+
template <>
7+
struct stype_map<ur_image_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_IMAGE_DESC> {};
8+
template <>
9+
struct stype_map<ur_buffer_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_PROPERTIES> {};
10+
template <>
11+
struct stype_map<ur_buffer_region_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_REGION> {};
12+
template <>
13+
struct stype_map<ur_buffer_channel_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_CHANNEL_PROPERTIES> {};
14+
template <>
15+
struct stype_map<ur_buffer_alloc_location_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_ALLOC_LOCATION_PROPERTIES> {};
16+
template <>
17+
struct stype_map<ur_program_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PROGRAM_PROPERTIES> {};
18+
template <>
19+
struct stype_map<ur_usm_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_DESC> {};
20+
template <>
21+
struct stype_map<ur_usm_host_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_HOST_DESC> {};
22+
template <>
23+
struct stype_map<ur_usm_device_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_DEVICE_DESC> {};
24+
template <>
25+
struct stype_map<ur_usm_pool_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_POOL_DESC> {};
26+
template <>
27+
struct stype_map<ur_usm_pool_limits_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC> {};
28+
template <>
29+
struct stype_map<ur_device_binary_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_BINARY> {};
30+
template <>
31+
struct stype_map<ur_sampler_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_SAMPLER_DESC> {};
32+
template <>
33+
struct stype_map<ur_queue_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_PROPERTIES> {};
34+
template <>
35+
struct stype_map<ur_queue_index_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_INDEX_PROPERTIES> {};
36+
template <>
37+
struct stype_map<ur_context_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_CONTEXT_NATIVE_PROPERTIES> {};
38+
template <>
39+
struct stype_map<ur_kernel_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES> {};
40+
template <>
41+
struct stype_map<ur_queue_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_NATIVE_PROPERTIES> {};
42+
template <>
43+
struct stype_map<ur_mem_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_MEM_NATIVE_PROPERTIES> {};
44+
template <>
45+
struct stype_map<ur_event_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES> {};
46+
template <>
47+
struct stype_map<ur_platform_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PLATFORM_NATIVE_PROPERTIES> {};
48+
template <>
49+
struct stype_map<ur_device_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES> {};
50+
template <>
51+
struct stype_map<ur_program_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES> {};
52+
template <>
53+
struct stype_map<ur_sampler_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_SAMPLER_NATIVE_PROPERTIES> {};
54+
template <>
55+
struct stype_map<ur_queue_native_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_NATIVE_DESC> {};
56+
template <>
57+
struct stype_map<ur_device_partition_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES> {};
58+
template <>
59+
struct stype_map<ur_kernel_arg_mem_obj_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES> {};
60+
template <>
61+
struct stype_map<ur_physical_mem_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES> {};
62+
template <>
63+
struct stype_map<ur_kernel_arg_pointer_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_POINTER_PROPERTIES> {};
64+
template <>
65+
struct stype_map<ur_kernel_arg_sampler_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_SAMPLER_PROPERTIES> {};
66+
template <>
67+
struct stype_map<ur_kernel_exec_info_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_EXEC_INFO_PROPERTIES> {};
68+
template <>
69+
struct stype_map<ur_kernel_arg_value_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_VALUE_PROPERTIES> {};
70+
template <>
71+
struct stype_map<ur_kernel_arg_local_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_LOCAL_PROPERTIES> {};
72+
template <>
73+
struct stype_map<ur_usm_alloc_location_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC> {};
74+
template <>
75+
struct stype_map<ur_exp_command_buffer_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC> {};
76+
template <>
77+
struct stype_map<ur_exp_command_buffer_update_kernel_launch_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC> {};
78+
template <>
79+
struct stype_map<ur_exp_command_buffer_update_memobj_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC> {};
80+
template <>
81+
struct stype_map<ur_exp_command_buffer_update_pointer_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC> {};
82+
template <>
83+
struct stype_map<ur_exp_command_buffer_update_value_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC> {};
84+
template <>
85+
struct stype_map<ur_exp_sampler_mip_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES> {};
86+
template <>
87+
struct stype_map<ur_exp_interop_mem_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_INTEROP_MEM_DESC> {};
88+
template <>
89+
struct stype_map<ur_exp_interop_semaphore_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_INTEROP_SEMAPHORE_DESC> {};
90+
template <>
91+
struct stype_map<ur_exp_file_descriptor_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_FILE_DESCRIPTOR> {};
92+
template <>
93+
struct stype_map<ur_exp_win32_handle_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_WIN32_HANDLE> {};
94+
template <>
95+
struct stype_map<ur_exp_sampler_addr_modes_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES> {};
96+
template <>
97+
struct stype_map<ur_exp_sampler_cubemap_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_CUBEMAP_PROPERTIES> {};
98+

source/common/ur_util.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,49 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) {
281281

282282
template <class> inline constexpr bool ur_always_false_t = false;
283283

284+
namespace {
285+
// Compile-time map, mapping a UR list node type, to the enum tag type
286+
// These are helpers for the `find_stype_node` helper below
287+
template <ur_structure_type_t val> struct stype_map_impl {
288+
static constexpr ur_structure_type_t value = val;
289+
};
290+
291+
template <typename T> struct stype_map {};
292+
// contains definitions of the map specializations e.g.
293+
// template <> struct stype_map<ur_usm_device_desc_t> :
294+
// stype_map_impl<UR_STRUCTURE_TYPE_USM_DEVICE_DESC> {};
295+
#include "stype_map_helpers.def"
296+
297+
template <typename T> constexpr int as_stype() { return stype_map<T>::value; };
298+
299+
/// Walk a generic UR linked list looking for a node of the given type. If it's
300+
/// found, its address is returned, othewise `nullptr`. e.g. to find out whether
301+
/// a `ur_usm_host_desc_t` exists in the given polymorphic list, `mylist`:
302+
///
303+
/// ```cpp
304+
/// auto *node = find_stype_node<ur_usm_host_desc_t>(mylist);
305+
/// if (!node)
306+
/// printf("node of expected type not found!\n");
307+
/// ```
308+
template <typename T, typename P>
309+
typename std::conditional_t<std::is_const_v<std::remove_pointer_t<P>>,
310+
const T *, T *>
311+
find_stype_node(P list_head) noexcept {
312+
auto *list = reinterpret_cast<const T *>(list_head);
313+
for (const auto *next = reinterpret_cast<const T *>(list); next;
314+
next = reinterpret_cast<const T *>(next->pNext)) {
315+
if (next->stype == as_stype<T>()) {
316+
if constexpr (!std::is_const_v<P>) {
317+
return const_cast<T *>(next);
318+
} else {
319+
return next;
320+
}
321+
}
322+
}
323+
return nullptr;
324+
}
325+
} // namespace
326+
284327
namespace ur {
285328
[[noreturn]] inline void unreachable() {
286329
#ifdef _MSC_VER

0 commit comments

Comments
 (0)