Skip to content

Commit 86dcd63

Browse files
alexbatashevaarongreig
authored andcommitted
[UR][Loader] Fix handling of native handles
Native handles are created by adapters and thus are inheritently backend-specific. Loader can not assume anything about these handles, as even nullptr may be a valid value for such a handle. This patch changes two things about native handles: 1) Native handles are no longer wrapped in UR objects 2) Dispatch table is extracted from any other argument of the API function The above is true for all interop APIs except for urPlatformCreateWithNativeHandle, which needs a spec change.
1 parent 8fb890d commit 86dcd63

File tree

2 files changed

+16
-122
lines changed

2 files changed

+16
-122
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,17 @@ namespace ur_loader
130130
%else:
131131
<%param_replacements={}%>
132132
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
133-
%if 0 == i:
133+
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
134134
// extract platform's function pointer table
135135
auto dditable = reinterpret_cast<${item['obj']}*>( ${item['pointer']}${item['name']} )->dditable;
136136
auto ${th.make_pfn_name(n, tags, obj)} = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
137137
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
138138
return ${X}_RESULT_ERROR_UNINITIALIZED;
139139
140+
<%break%>
140141
%endif
142+
%endfor
143+
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
141144
%if 'range' in item:
142145
<%
143146
add_local = True
@@ -146,13 +149,15 @@ namespace ur_loader
146149
for( size_t i = ${item['range'][0]}; i < ${item['range'][1]}; ++i )
147150
${item['name']}Local[ i ] = reinterpret_cast<${item['obj']}*>( ${item['name']}[ i ] )->handle;
148151
%else:
152+
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
149153
// convert loader handle to platform handle
150154
%if item['optional']:
151155
${item['name']} = ( ${item['name']} ) ? reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle : nullptr;
152156
%else:
153157
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
154158
%endif
155159
%endif
160+
%endif
156161
157162
%endfor
158163
// forward to device-platform
@@ -173,7 +178,7 @@ namespace ur_loader
173178
%if item['release']:
174179
// release loader handle
175180
${item['factory']}.release( ${item['name']} );
176-
%else:
181+
%elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
177182
try
178183
{
179184
%if 'range' in item:

source/loader/ur_ldrddi.cpp

Lines changed: 9 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,6 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle(
352352
return result;
353353
}
354354

355-
try {
356-
// convert platform handle to loader handle
357-
*phNativePlatform = reinterpret_cast<ur_native_handle_t>(
358-
ur_native_factory.getInstance(*phNativePlatform, dditable));
359-
} catch (std::bad_alloc &) {
360-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
361-
}
362-
363355
return result;
364356
}
365357

@@ -673,14 +665,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
673665
return result;
674666
}
675667

676-
try {
677-
// convert platform handle to loader handle
678-
*phNativeDevice = reinterpret_cast<ur_native_handle_t>(
679-
ur_native_factory.getInstance(*phNativeDevice, dditable));
680-
} catch (std::bad_alloc &) {
681-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
682-
}
683-
684668
return result;
685669
}
686670

@@ -699,17 +683,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
699683

700684
// extract platform's function pointer table
701685
auto dditable =
702-
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->dditable;
686+
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
703687
auto pfnCreateWithNativeHandle =
704688
dditable->ur.Device.pfnCreateWithNativeHandle;
705689
if (nullptr == pfnCreateWithNativeHandle) {
706690
return UR_RESULT_ERROR_UNINITIALIZED;
707691
}
708692

709-
// convert loader handle to platform handle
710-
hNativeDevice =
711-
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->handle;
712-
713693
// convert loader handle to platform handle
714694
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;
715695

@@ -916,14 +896,6 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
916896
return result;
917897
}
918898

919-
try {
920-
// convert platform handle to loader handle
921-
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
922-
ur_native_factory.getInstance(*phNativeContext, dditable));
923-
} catch (std::bad_alloc &) {
924-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
925-
}
926-
927899
return result;
928900
}
929901

@@ -944,17 +916,13 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
944916

945917
// extract platform's function pointer table
946918
auto dditable =
947-
reinterpret_cast<ur_native_object_t *>(hNativeContext)->dditable;
919+
reinterpret_cast<ur_device_object_t *>(*phDevices)->dditable;
948920
auto pfnCreateWithNativeHandle =
949921
dditable->ur.Context.pfnCreateWithNativeHandle;
950922
if (nullptr == pfnCreateWithNativeHandle) {
951923
return UR_RESULT_ERROR_UNINITIALIZED;
952924
}
953925

954-
// convert loader handle to platform handle
955-
hNativeContext =
956-
reinterpret_cast<ur_native_object_t *>(hNativeContext)->handle;
957-
958926
// convert loader handles to platform handles
959927
auto phDevicesLocal = std::vector<ur_device_handle_t>(numDevices);
960928
for (size_t i = 0; i < numDevices; ++i) {
@@ -1207,14 +1175,6 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
12071175
return result;
12081176
}
12091177

1210-
try {
1211-
// convert platform handle to loader handle
1212-
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
1213-
ur_native_factory.getInstance(*phNativeMem, dditable));
1214-
} catch (std::bad_alloc &) {
1215-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
1216-
}
1217-
12181178
return result;
12191179
}
12201180

@@ -1232,17 +1192,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
12321192
ur_result_t result = UR_RESULT_SUCCESS;
12331193

12341194
// extract platform's function pointer table
1235-
auto dditable =
1236-
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
1195+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
12371196
auto pfnBufferCreateWithNativeHandle =
12381197
dditable->ur.Mem.pfnBufferCreateWithNativeHandle;
12391198
if (nullptr == pfnBufferCreateWithNativeHandle) {
12401199
return UR_RESULT_ERROR_UNINITIALIZED;
12411200
}
12421201

1243-
// convert loader handle to platform handle
1244-
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;
1245-
12461202
// convert loader handle to platform handle
12471203
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
12481204

@@ -1282,17 +1238,13 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
12821238
ur_result_t result = UR_RESULT_SUCCESS;
12831239

12841240
// extract platform's function pointer table
1285-
auto dditable =
1286-
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
1241+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
12871242
auto pfnImageCreateWithNativeHandle =
12881243
dditable->ur.Mem.pfnImageCreateWithNativeHandle;
12891244
if (nullptr == pfnImageCreateWithNativeHandle) {
12901245
return UR_RESULT_ERROR_UNINITIALIZED;
12911246
}
12921247

1293-
// convert loader handle to platform handle
1294-
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;
1295-
12961248
// convert loader handle to platform handle
12971249
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
12981250

@@ -1528,14 +1480,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle(
15281480
return result;
15291481
}
15301482

1531-
try {
1532-
// convert platform handle to loader handle
1533-
*phNativeSampler = reinterpret_cast<ur_native_handle_t>(
1534-
ur_native_factory.getInstance(*phNativeSampler, dditable));
1535-
} catch (std::bad_alloc &) {
1536-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
1537-
}
1538-
15391483
return result;
15401484
}
15411485

@@ -1553,18 +1497,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle(
15531497
ur_result_t result = UR_RESULT_SUCCESS;
15541498

15551499
// extract platform's function pointer table
1556-
auto dditable =
1557-
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->dditable;
1500+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
15581501
auto pfnCreateWithNativeHandle =
15591502
dditable->ur.Sampler.pfnCreateWithNativeHandle;
15601503
if (nullptr == pfnCreateWithNativeHandle) {
15611504
return UR_RESULT_ERROR_UNINITIALIZED;
15621505
}
15631506

1564-
// convert loader handle to platform handle
1565-
hNativeSampler =
1566-
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->handle;
1567-
15681507
// convert loader handle to platform handle
15691508
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
15701509

@@ -2604,14 +2543,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle(
26042543
return result;
26052544
}
26062545

2607-
try {
2608-
// convert platform handle to loader handle
2609-
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(
2610-
ur_native_factory.getInstance(*phNativeProgram, dditable));
2611-
} catch (std::bad_alloc &) {
2612-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
2613-
}
2614-
26152546
return result;
26162547
}
26172548

@@ -2629,18 +2560,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
26292560
ur_result_t result = UR_RESULT_SUCCESS;
26302561

26312562
// extract platform's function pointer table
2632-
auto dditable =
2633-
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->dditable;
2563+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
26342564
auto pfnCreateWithNativeHandle =
26352565
dditable->ur.Program.pfnCreateWithNativeHandle;
26362566
if (nullptr == pfnCreateWithNativeHandle) {
26372567
return UR_RESULT_ERROR_UNINITIALIZED;
26382568
}
26392569

2640-
// convert loader handle to platform handle
2641-
hNativeProgram =
2642-
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->handle;
2643-
26442570
// convert loader handle to platform handle
26452571
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
26462572

@@ -3088,14 +3014,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle(
30883014
return result;
30893015
}
30903016

3091-
try {
3092-
// convert platform handle to loader handle
3093-
*phNativeKernel = reinterpret_cast<ur_native_handle_t>(
3094-
ur_native_factory.getInstance(*phNativeKernel, dditable));
3095-
} catch (std::bad_alloc &) {
3096-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3097-
}
3098-
30993017
return result;
31003018
}
31013019

@@ -3115,18 +3033,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
31153033
ur_result_t result = UR_RESULT_SUCCESS;
31163034

31173035
// extract platform's function pointer table
3118-
auto dditable =
3119-
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->dditable;
3036+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
31203037
auto pfnCreateWithNativeHandle =
31213038
dditable->ur.Kernel.pfnCreateWithNativeHandle;
31223039
if (nullptr == pfnCreateWithNativeHandle) {
31233040
return UR_RESULT_ERROR_UNINITIALIZED;
31243041
}
31253042

3126-
// convert loader handle to platform handle
3127-
hNativeKernel =
3128-
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->handle;
3129-
31303043
// convert loader handle to platform handle
31313044
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
31323045

@@ -3300,14 +3213,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle(
33003213
return result;
33013214
}
33023215

3303-
try {
3304-
// convert platform handle to loader handle
3305-
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
3306-
ur_native_factory.getInstance(*phNativeQueue, dditable));
3307-
} catch (std::bad_alloc &) {
3308-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3309-
}
3310-
33113216
return result;
33123217
}
33133218

@@ -3326,17 +3231,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
33263231
ur_result_t result = UR_RESULT_SUCCESS;
33273232

33283233
// extract platform's function pointer table
3329-
auto dditable =
3330-
reinterpret_cast<ur_native_object_t *>(hNativeQueue)->dditable;
3234+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
33313235
auto pfnCreateWithNativeHandle =
33323236
dditable->ur.Queue.pfnCreateWithNativeHandle;
33333237
if (nullptr == pfnCreateWithNativeHandle) {
33343238
return UR_RESULT_ERROR_UNINITIALIZED;
33353239
}
33363240

3337-
// convert loader handle to platform handle
3338-
hNativeQueue = reinterpret_cast<ur_native_object_t *>(hNativeQueue)->handle;
3339-
33403241
// convert loader handle to platform handle
33413242
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
33423243

@@ -3573,14 +3474,6 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle(
35733474
return result;
35743475
}
35753476

3576-
try {
3577-
// convert platform handle to loader handle
3578-
*phNativeEvent = reinterpret_cast<ur_native_handle_t>(
3579-
ur_native_factory.getInstance(*phNativeEvent, dditable));
3580-
} catch (std::bad_alloc &) {
3581-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3582-
}
3583-
35843477
return result;
35853478
}
35863479

@@ -3598,17 +3491,13 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
35983491
ur_result_t result = UR_RESULT_SUCCESS;
35993492

36003493
// extract platform's function pointer table
3601-
auto dditable =
3602-
reinterpret_cast<ur_native_object_t *>(hNativeEvent)->dditable;
3494+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
36033495
auto pfnCreateWithNativeHandle =
36043496
dditable->ur.Event.pfnCreateWithNativeHandle;
36053497
if (nullptr == pfnCreateWithNativeHandle) {
36063498
return UR_RESULT_ERROR_UNINITIALIZED;
36073499
}
36083500

3609-
// convert loader handle to platform handle
3610-
hNativeEvent = reinterpret_cast<ur_native_object_t *>(hNativeEvent)->handle;
3611-
36123501
// convert loader handle to platform handle
36133502
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
36143503

0 commit comments

Comments
 (0)