Skip to content

Commit 2a82754

Browse files
authored
Merge pull request #1116 from pbalcer/fix-loader-getinfo-handles
[loader] perform handle conversion after info queries
2 parents b05c5b5 + 81c8b1b commit 2a82754

File tree

14 files changed

+931
-55
lines changed

14 files changed

+931
-55
lines changed

include/ur_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5103,7 +5103,7 @@ urKernelCreateWithNativeHandle(
51035103
///////////////////////////////////////////////////////////////////////////////
51045104
/// @brief Query queue info
51055105
typedef enum ur_queue_info_t {
5106-
UR_QUEUE_INFO_CONTEXT = 0, ///< [::ur_queue_handle_t] context associated with this queue.
5106+
UR_QUEUE_INFO_CONTEXT = 0, ///< [::ur_context_handle_t] context associated with this queue.
51075107
UR_QUEUE_INFO_DEVICE = 1, ///< [::ur_device_handle_t] device associated with this queue.
51085108
UR_QUEUE_INFO_DEVICE_DEFAULT = 2, ///< [::ur_queue_handle_t] the current default queue of the underlying
51095109
///< device.

include/ur_print.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8018,9 +8018,9 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_queue_info_
80188018

80198019
switch (value) {
80208020
case UR_QUEUE_INFO_CONTEXT: {
8021-
const ur_queue_handle_t *tptr = (const ur_queue_handle_t *)ptr;
8022-
if (sizeof(ur_queue_handle_t) > size) {
8023-
os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_queue_handle_t) << ")";
8021+
const ur_context_handle_t *tptr = (const ur_context_handle_t *)ptr;
8022+
if (sizeof(ur_context_handle_t) > size) {
8023+
os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_context_handle_t) << ")";
80248024
return UR_RESULT_ERROR_INVALID_SIZE;
80258025
}
80268026
os << (const void *)(tptr) << " (";

scripts/core/queue.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ name: $x_queue_info_t
1919
typed_etors: True
2020
etors:
2121
- name: CONTEXT
22-
desc: "[$x_queue_handle_t] context associated with this queue."
22+
desc: "[$x_context_handle_t] context associated with this queue."
2323
- name: DEVICE
2424
desc: "[$x_device_handle_t] device associated with this queue."
2525
- name: DEVICE_DEFAULT

scripts/templates/helper.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def is_handle(obj):
3939
except:
4040
return False
4141

42+
@staticmethod
43+
def is_enum(obj):
44+
try:
45+
return True if re.match(r"enum", obj['type']) else False
46+
except:
47+
return False
48+
4249
@staticmethod
4350
def is_experimental(obj):
4451
try:
@@ -449,6 +456,13 @@ def is_release(cls, item):
449456
except:
450457
return False
451458

459+
@classmethod
460+
def is_typename(cls, item):
461+
try:
462+
return True if re.match(cls.RE_TYPENAME, item['desc']) else False
463+
except:
464+
return False
465+
452466
@classmethod
453467
def typename(cls, item):
454468
match = re.match(cls.RE_TYPENAME, item['desc'])
@@ -1241,24 +1255,43 @@ def get_loader_prologue(namespace, tags, obj, meta):
12411255

12421256
return prologue
12431257

1258+
"""
1259+
Public:
1260+
returns an enum object with the given name
1261+
"""
1262+
def get_enum_by_name(specs, namespace, tags, name, only_typed):
1263+
for s in specs:
1264+
for obj in s['objects']:
1265+
if obj_traits.is_enum(obj) and make_enum_name(namespace, tags, obj) == name:
1266+
typed = obj.get('typed_etors', False) is True
1267+
if only_typed:
1268+
if typed:
1269+
return obj
1270+
else:
1271+
return None
1272+
else:
1273+
return obj
1274+
return None
1275+
12441276
"""
12451277
Public:
12461278
returns a list of dict for converting loader output parameters
12471279
"""
1248-
def get_loader_epilogue(namespace, tags, obj, meta):
1280+
def get_loader_epilogue(specs, namespace, tags, obj, meta):
12491281
epilogue = []
12501282

12511283
for i, item in enumerate(obj['params']):
12521284
if param_traits.is_mbz(item):
12531285
continue
1254-
if param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item):
1255-
if type_traits.is_class_handle(item['type'], meta):
1256-
name = subt(namespace, tags, item['name'])
1257-
tname = _remove_const_ptr(subt(namespace, tags, item['type']))
12581286

1259-
obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname)
1260-
fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname)
1287+
name = subt(namespace, tags, item['name'])
1288+
tname = _remove_const_ptr(subt(namespace, tags, item['type']))
1289+
1290+
obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname)
1291+
fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname)
12611292

1293+
if param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item):
1294+
if type_traits.is_class_handle(item['type'], meta):
12621295
if param_traits.is_range(item):
12631296
range_start = param_traits.range_start(item)
12641297
range_end = param_traits.range_end(item)
@@ -1279,6 +1312,44 @@ def get_loader_epilogue(namespace, tags, obj, meta):
12791312
'release': param_traits.is_release(item),
12801313
'optional': param_traits.is_optional(item)
12811314
})
1315+
elif param_traits.is_typename(item):
1316+
typename = param_traits.typename(item)
1317+
underlying_type = None
1318+
for inner in obj['params']:
1319+
iname = _get_param_name(namespace, tags, inner)
1320+
if iname == typename:
1321+
underlying_type = _get_type_name(namespace, tags, obj, inner)
1322+
if underlying_type is None:
1323+
continue
1324+
1325+
prop_size = param_traits.typename_size(item)
1326+
enum = get_enum_by_name(specs, namespace, tags, underlying_type, True)
1327+
handle_etors = []
1328+
for etor in enum['etors']:
1329+
associated_type = etor_get_associated_type(namespace, tags, etor)
1330+
if 'handle' in associated_type:
1331+
is_array = False
1332+
if value_traits.is_array(associated_type):
1333+
associated_type = value_traits.get_array_name(associated_type)
1334+
is_array = True
1335+
1336+
etor_name = make_etor_name(namespace, tags, enum['name'], etor['name'])
1337+
obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", associated_type)
1338+
fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", associated_type)
1339+
handle_etors.append({'name': etor_name,
1340+
'type': associated_type,
1341+
'obj': obj_name,
1342+
'factory': fty_name,
1343+
'is_array': is_array})
1344+
1345+
if handle_etors:
1346+
epilogue.append({
1347+
'name': name,
1348+
'obj': obj_name,
1349+
'release': False,
1350+
'typename': typename,
1351+
'size': prop_size,
1352+
'etors': handle_etors})
12821353

12831354
return epilogue
12841355

scripts/templates/ldrddi.cpp.mako

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,23 @@ namespace ur_loader
160160
%endif
161161
162162
%endfor
163+
164+
<%
165+
epilogue = th.get_loader_epilogue(specs, n, tags, obj, meta)
166+
has_typename = False
167+
for item in epilogue:
168+
if 'typename' in item:
169+
has_typename = True
170+
break
171+
%>
172+
173+
%if has_typename:
174+
// this value is needed for converting adapter handles to loader handles
175+
size_t sizeret = 0;
176+
if (pPropSizeRet == NULL)
177+
pPropSizeRet = &sizeret;
178+
%endif
179+
163180
// forward to device-platform
164181
%if add_local:
165182
result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name", "local"], replacements=param_replacements))} );
@@ -168,8 +185,9 @@ namespace ur_loader
168185
%endif
169186
<%
170187
del param_replacements
171-
del add_local%>
172-
%for i, item in enumerate(th.get_loader_epilogue(n, tags, obj, meta)):
188+
del add_local
189+
%>
190+
%for i, item in enumerate(epilogue):
173191
%if 0 == i:
174192
if( ${X}_RESULT_SUCCESS != result )
175193
return result;
@@ -181,7 +199,25 @@ namespace ur_loader
181199
%elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
182200
try
183201
{
184-
%if 'range' in item:
202+
%if 'typename' in item:
203+
if (${item['name']} != nullptr) {
204+
switch (${item['typename']}) {
205+
%for etor in item['etors']:
206+
case ${etor['name']}: {
207+
${etor['type']} *handles = reinterpret_cast<${etor['type']} *>(${item['name']});
208+
size_t nelements = *pPropSizeRet / sizeof(${etor['type']});
209+
for (size_t i = 0; i < nelements; ++i) {
210+
if (handles[i] != nullptr) {
211+
handles[i] = reinterpret_cast<${etor['type']}>(
212+
${etor['factory']}.getInstance( handles[i], dditable ) );
213+
}
214+
}
215+
} break;
216+
%endfor
217+
default: {} break;
218+
}
219+
}
220+
%elif 'range' in item:
185221
// convert platform handles to loader handles
186222
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
187223
${item['name']}[ i ] = reinterpret_cast<${item['type']}>(

scripts/templates/nullddi.cpp.mako

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,23 @@ namespace driver
4848
else
4949
{
5050
// generic implementation
51-
%for item in th.get_loader_epilogue(n, tags, obj, meta):
52-
%if 'range' in item:
51+
%for item in th.get_loader_epilogue(specs, n, tags, obj, meta):
52+
%if 'typename' in item:
53+
if (${item['name']} != nullptr) {
54+
switch (${item['typename']}) {
55+
%for etor in item['etors']:
56+
case ${etor['name']}: {
57+
${etor['type']} *handles = reinterpret_cast<${etor['type']} *>(${item['name']});
58+
size_t nelements = ${item['size']} / sizeof(${etor['type']});
59+
for (size_t i = 0; i < nelements; ++i) {
60+
handles[i] = reinterpret_cast<${etor['type']}>( d_context.get() );
61+
}
62+
} break;
63+
%endfor
64+
default: {} break;
65+
}
66+
}
67+
%elif 'range' in item:
5368
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
5469
${item['name']}[ i ] = reinterpret_cast<${item['type']}>( d_context.get() );
5570
%elif not item['release']:

source/adapters/null/ur_null.cpp

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ context_t d_context;
1717

1818
//////////////////////////////////////////////////////////////////////////
1919
context_t::context_t() {
20+
platform = get();
2021
//////////////////////////////////////////////////////////////////////////
2122
urDdiTable.Global.pfnAdapterGet = [](uint32_t NumAdapters,
2223
ur_adapter_handle_t *phAdapters,
@@ -28,7 +29,7 @@ context_t::context_t() {
2829
*pNumAdapters = 1;
2930
}
3031
if (nullptr != phAdapters) {
31-
*reinterpret_cast<void **>(phAdapters) = d_context.get();
32+
*reinterpret_cast<void **>(phAdapters) = d_context.platform;
3233
}
3334

3435
return UR_RESULT_SUCCESS;
@@ -48,7 +49,7 @@ context_t::context_t() {
4849
*pNumPlatforms = 1;
4950
}
5051
if (nullptr != phPlatforms) {
51-
*reinterpret_cast<void **>(phPlatforms) = d_context.get();
52+
*reinterpret_cast<void **>(phPlatforms) = d_context.platform;
5253
}
5354
return UR_RESULT_SUCCESS;
5455
};
@@ -120,48 +121,59 @@ context_t::context_t() {
120121
};
121122

122123
//////////////////////////////////////////////////////////////////////////
123-
urDdiTable.Device.pfnGetInfo =
124-
[](ur_device_handle_t, ur_device_info_t infoType, size_t propSize,
125-
void *pDeviceInfo, size_t *pPropSizeRet) {
126-
switch (infoType) {
127-
case UR_DEVICE_INFO_TYPE:
128-
if (pDeviceInfo && propSize != sizeof(ur_device_type_t)) {
129-
return UR_RESULT_ERROR_INVALID_SIZE;
130-
}
124+
urDdiTable.Device.pfnGetInfo = [](ur_device_handle_t,
125+
ur_device_info_t infoType,
126+
size_t propSize, void *pDeviceInfo,
127+
size_t *pPropSizeRet) {
128+
switch (infoType) {
129+
case UR_DEVICE_INFO_TYPE:
130+
if (pDeviceInfo && propSize != sizeof(ur_device_type_t)) {
131+
return UR_RESULT_ERROR_INVALID_SIZE;
132+
}
131133

132-
if (pDeviceInfo != nullptr) {
133-
*reinterpret_cast<ur_device_type_t *>(pDeviceInfo) =
134-
UR_DEVICE_TYPE_GPU;
135-
}
136-
if (pPropSizeRet != nullptr) {
137-
*pPropSizeRet = sizeof(ur_device_type_t);
138-
}
139-
break;
134+
if (pDeviceInfo != nullptr) {
135+
*reinterpret_cast<ur_device_type_t *>(pDeviceInfo) =
136+
UR_DEVICE_TYPE_GPU;
137+
}
138+
if (pPropSizeRet != nullptr) {
139+
*pPropSizeRet = sizeof(ur_device_type_t);
140+
}
141+
break;
140142

141-
case UR_DEVICE_INFO_NAME: {
142-
char deviceName[] = "Null Device";
143-
if (pDeviceInfo && propSize < sizeof(deviceName)) {
144-
return UR_RESULT_ERROR_INVALID_SIZE;
145-
}
146-
if (pDeviceInfo != nullptr) {
143+
case UR_DEVICE_INFO_NAME: {
144+
char deviceName[] = "Null Device";
145+
if (pDeviceInfo && propSize < sizeof(deviceName)) {
146+
return UR_RESULT_ERROR_INVALID_SIZE;
147+
}
148+
if (pDeviceInfo != nullptr) {
147149
#if defined(_WIN32)
148-
strncpy_s(reinterpret_cast<char *>(pDeviceInfo), propSize,
149-
deviceName, sizeof(deviceName));
150+
strncpy_s(reinterpret_cast<char *>(pDeviceInfo), propSize,
151+
deviceName, sizeof(deviceName));
150152
#else
151-
strncpy(reinterpret_cast<char *>(pDeviceInfo), deviceName,
152-
propSize);
153+
strncpy(reinterpret_cast<char *>(pDeviceInfo), deviceName,
154+
propSize);
153155
#endif
154-
}
155-
if (pPropSizeRet != nullptr) {
156-
*pPropSizeRet = sizeof(deviceName);
157-
}
158-
} break;
159-
160-
default:
161-
return UR_RESULT_ERROR_INVALID_ARGUMENT;
162156
}
163-
return UR_RESULT_SUCCESS;
164-
};
157+
if (pPropSizeRet != nullptr) {
158+
*pPropSizeRet = sizeof(deviceName);
159+
}
160+
} break;
161+
case UR_DEVICE_INFO_PLATFORM: {
162+
if (pDeviceInfo && propSize < sizeof(pDeviceInfo)) {
163+
return UR_RESULT_ERROR_INVALID_SIZE;
164+
}
165+
if (pDeviceInfo != nullptr) {
166+
*reinterpret_cast<void **>(pDeviceInfo) = d_context.platform;
167+
}
168+
if (pPropSizeRet != nullptr) {
169+
*pPropSizeRet = sizeof(intptr_t);
170+
}
171+
} break;
172+
default:
173+
return UR_RESULT_ERROR_INVALID_ARGUMENT;
174+
}
175+
return UR_RESULT_SUCCESS;
176+
};
165177

166178
//////////////////////////////////////////////////////////////////////////
167179
urDdiTable.USM.pfnHostAlloc = [](ur_context_handle_t, const ur_usm_desc_t *,

source/adapters/null/ur_null.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
* @file ur_null.hpp
1010
*
1111
*/
12+
#include "ur_api.h"
1213
#ifndef UR_ADAPTER_NULL_H
1314
#define UR_ADAPTER_NULL_H 1
1415

@@ -27,6 +28,8 @@ class __urdlllocal context_t {
2728
context_t();
2829
~context_t() = default;
2930

31+
void *platform;
32+
3033
void *get() {
3134
static uint64_t count = 0x80800000;
3235
return reinterpret_cast<void *>(++count);

0 commit comments

Comments
 (0)