Skip to content

Commit 263d128

Browse files
committed
UR port plugin first pass
1 parent 6cd6317 commit 263d128

File tree

9 files changed

+526
-5
lines changed

9 files changed

+526
-5
lines changed

sycl/include/sycl/backend.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
#include <type_traits> // for enable_if_t
5656
#include <vector> // for vector
5757

58+
#include <ur_api.h>
59+
5860
namespace sycl {
5961
inline namespace _V1 {
6062

@@ -64,6 +66,7 @@ namespace detail {
6466
enum class backend_errc : unsigned int {};
6567

6668
// Convert from PI backend to SYCL backend enum
69+
backend convertUrBackend(ur_platform_backend_t UrBackend);
6770
backend convertBackend(pi_platform_backend PiBackend);
6871
} // namespace detail
6972

sycl/include/sycl/detail/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ inline std::string codeToString(pi_int32 code) {
191191
#define __SYCL_REPORT_PI_ERR_TO_EXC(expr, exc, str) \
192192
{ \
193193
auto code = expr; \
194-
if (code != PI_SUCCESS) { \
194+
if (code != UR_RESULT_SUCCESS) { \
195195
std::string err_str = \
196196
str ? "\n" + std::string(str) + "\n" : std::string{}; \
197197
throw exc(__SYCL_PI_ERROR_REPORT + sycl::detail::codeToString(code) + \
@@ -211,7 +211,7 @@ inline std::string codeToString(pi_int32 code) {
211211
#define __SYCL_REPORT_ERR_TO_EXC_VIA_ERRC(expr, errc) \
212212
{ \
213213
auto code = expr; \
214-
if (code != PI_SUCCESS) { \
214+
if (code != UR_RESULT_SUCCESS) { \
215215
throw sycl::exception(sycl::make_error_code(errc), \
216216
__SYCL_PI_ERROR_REPORT + \
217217
sycl::detail::codeToString(code)); \

sycl/include/sycl/detail/pi.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
1818
#include <sycl/detail/os_util.hpp> // for __SYCL_RT_OS_LINUX
1919
#include <sycl/detail/pi.h> // for piContextCreate, piContextGetInfo
20-
20+
//
2121
#include <cstdint> // for uint64_t, uint32_t
2222
#include <memory> // for shared_ptr
2323
#include <sstream> // for operator<<, basic_ostream, string...
@@ -44,9 +44,18 @@ enum class PiApiKind {
4444
#define _PI_API(api) api,
4545
#include <sycl/detail/pi.def>
4646
};
47+
48+
enum class UrApiKind {
49+
#define _UR_API(api) api,
50+
#include <sycl/detail/ur.def>
51+
};
52+
4753
class plugin;
4854
using PluginPtr = std::shared_ptr<plugin>;
4955

56+
class urPlugin;
57+
using UrPluginPtr = std::shared_ptr<urPlugin>;
58+
5059
template <sycl::backend BE>
5160
__SYCL_EXPORT void *getPluginOpaqueData(void *opaquedata_arg);
5261

@@ -208,6 +217,7 @@ extern std::shared_ptr<plugin> GlobalPlugin;
208217

209218
// Performs PI one-time initialization.
210219
std::vector<PluginPtr> &initialize();
220+
std::vector<UrPluginPtr> &initializeUr();
211221

212222
// Get the plugin serving given backend.
213223
template <backend BE> __SYCL_EXPORT const PluginPtr &getPlugin();
@@ -224,6 +234,19 @@ template <PiApiKind PiApiOffset> struct PiFuncInfo {};
224234
} \
225235
};
226236
#include <sycl/detail/pi.def>
237+
/*
238+
// Utility Functions to get Function Name for a PI Api.
239+
template <UrApiKind UrApiOffset> struct UrFuncInfo {};
240+
241+
#define _UR_API(api) \
242+
template <> struct UrFuncInfo<UrApiKind::api> { \
243+
inline const char *getFuncName() { return #api; } \
244+
//inline FuncPtrT getFuncPtr(UrPlugin MPlugin) { \
245+
// return MPlugin.PiFunctionTable.api; \
246+
//} \
247+
};
248+
#include <sycl/detail/ur.def>
249+
*/
227250

228251
/// Emits an XPTI trace before a PI API call is made
229252
/// \param FName The name of the PI API call

sycl/include/sycl/detail/ur.def

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
//==------------ ur.def Plugin Interface list of API -----------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _UR_API
10+
#error Undefined _UR_API macro expansion
11+
#endif
12+
13+
// The list of all PI interfaces wrapped with _UR_API macro.
14+
// This is for convinience of doing same thing for all interfaces, e.g.
15+
// declare, define, initialize.
16+
//
17+
// This list is used to define PiAurKind enum, which is part of ernal
18+
// interface. To avoid ABI breakage, please, add new entries to the end of the
19+
// list.
20+
//
21+
// Platform
22+
_UR_API(urPlatformGet)
23+
_UR_API(urPlatformGetInfo)
24+
_UR_API(urPlatformGetNativeHandle)
25+
_UR_API(urPlatformCreateWithNativeHandle)
26+
// Device
27+
_UR_API(urDeviceGet)
28+
_UR_API(urDeviceGetInfo)
29+
_UR_API(urDevicePartition)
30+
_UR_API(urDeviceRetain)
31+
_UR_API(urDeviceRelease)
32+
_UR_API(urDeviceSelectBinary)
33+
_UR_API(urDeviceGetNativeHandle)
34+
_UR_API(urDeviceCreateWithNativeHandle)
35+
// Cont
36+
_UR_API(urContextCreate)
37+
_UR_API(urContextGetInfo)
38+
_UR_API(urContextRetain)
39+
_UR_API(urContextRelease)
40+
_UR_API(urContextSetExtendedDeleter)
41+
_UR_API(urContextGetNativeHandle)
42+
_UR_API(urContextCreateWithNativeHandle)
43+
// Queue
44+
_UR_API(urQueueCreate)
45+
_UR_API(urQueueGetInfo)
46+
_UR_API(urQueueFinish)
47+
_UR_API(urQueueFlush)
48+
_UR_API(urQueueRetain)
49+
_UR_API(urQueueRelease)
50+
_UR_API(urQueueGetNativeHandle)
51+
_UR_API(urQueueCreateWithNativeHandle)
52+
// Memory
53+
_UR_API(urMemBufferCreate)
54+
_UR_API(urMemImageCreate)
55+
_UR_API(urMemGetInfo)
56+
_UR_API(urMemImageGetInfo)
57+
_UR_API(urMemRetain)
58+
_UR_API(urMemRelease)
59+
_UR_API(urMemBufferPartition)
60+
_UR_API(urMemGetNativeHandle)
61+
_UR_API(urMemBufferCreateWithNativeHandle)
62+
_UR_API(urMemImageCreateWithNativeHandle)
63+
// Program
64+
_UR_API(urProgramCreateWithIL)
65+
_UR_API(urProgramCreateWithBinary)
66+
_UR_API(urProgramGetInfo)
67+
_UR_API(urProgramCompile)
68+
_UR_API(urProgramBuild)
69+
_UR_API(urProgramLink)
70+
_UR_API(urProgramGetBuildInfo)
71+
_UR_API(urProgramRetain)
72+
_UR_API(urProgramRelease)
73+
_UR_API(urProgramSetSpecializationConstants)
74+
_UR_API(urProgramGetNativeHandle)
75+
_UR_API(urProgramCreateWithNativeHandle)
76+
// Kernel
77+
_UR_API(urKernelCreate)
78+
_UR_API(urKernelSetArgValue)
79+
_UR_API(urKernelSetArgLocal)
80+
_UR_API(urKernelGetInfo)
81+
_UR_API(urKernelGetGroupInfo)
82+
_UR_API(urKernelGetSubGroupInfo)
83+
_UR_API(urKernelRetain)
84+
_UR_API(urKernelRelease)
85+
_UR_API(urKernelSetArgPointer)
86+
_UR_API(urKernelSetExecInfo)
87+
_UR_API(urKernelSetArgSampler)
88+
_UR_API(urKernelSetArgMemObj)
89+
_UR_API(urKernelCreateWithNativeHandle)
90+
_UR_API(urKernelGetNativeHandle)
91+
// Event
92+
_UR_API(urEventGetInfo)
93+
_UR_API(urEventGetProfilingInfo)
94+
_UR_API(urEventWait)
95+
_UR_API(urEventSetCallback)
96+
_UR_API(urEventRetain)
97+
_UR_API(urEventRelease)
98+
_UR_API(urEventGetNativeHandle)
99+
_UR_API(urEventCreateWithNativeHandle)
100+
// Sampler
101+
_UR_API(urSamplerCreate)
102+
_UR_API(urSamplerGetInfo)
103+
_UR_API(urSamplerRetain)
104+
_UR_API(urSamplerRelease)
105+
// Queue commands
106+
_UR_API(urEnqueueKernelLaunch)
107+
_UR_API(urEnqueueEventsWait)
108+
_UR_API(urEnqueueEventsWaitWithBarrier)
109+
_UR_API(urEnqueueMemBufferRead)
110+
_UR_API(urEnqueueMemBufferReadRect)
111+
_UR_API(urEnqueueMemBufferWrite)
112+
_UR_API(urEnqueueMemBufferWriteRect)
113+
_UR_API(urEnqueueMemBufferCopy)
114+
_UR_API(urEnqueueMemBufferCopyRect)
115+
_UR_API(urEnqueueMemBufferFill)
116+
_UR_API(urEnqueueMemImageRead)
117+
_UR_API(urEnqueueMemImageWrite)
118+
_UR_API(urEnqueueMemImageCopy)
119+
_UR_API(urEnqueueMemBufferMap)
120+
_UR_API(urEnqueueMemUnmap)
121+
// USM
122+
_UR_API(urUSMHostAlloc)
123+
_UR_API(urUSMDeviceAlloc)
124+
_UR_API(urUSMSharedAlloc)
125+
_UR_API(urUSMFree)
126+
_UR_API(urEnqueueUSMFill)
127+
_UR_API(urEnqueueUSMMemcpy)
128+
_UR_API(urEnqueueUSMPrefetch)
129+
_UR_API(urEnqueueUSMAdvise)
130+
_UR_API(urUSMGetMemAllocInfo)
131+
// Host urpes
132+
_UR_API(urEnqueueReadHostPipe)
133+
_UR_API(urEnqueueWriteHostPipe)
134+
135+
_UR_API(urAdapterGetLastError)
136+
137+
_UR_API(urEnqueueUSMFill2D)
138+
_UR_API(urEnqueueUSMMemcpy2D)
139+
140+
_UR_API(urDeviceGetGlobalTimestamps)
141+
142+
/*
143+
// Device global variable
144+
_UR_API(urEnqueueDeviceGlobalVariableWrite)
145+
_UR_API(urEnqueueDeviceGlobalVariableRead)
146+
147+
_UR_API(urPluginGetBackendOption)
148+
149+
_UR_API(urEnablePeerAccess)
150+
_UR_API(urDisablePeerAccess)
151+
_UR_API(urPeerAccessGetInfo)
152+
153+
// USM import/release APIs
154+
_UR_API(urUSMImport)
155+
_UR_API(urUSMRelease)
156+
157+
// command-buffer Extension
158+
_UR_API(urCommandBufferCreate)
159+
_UR_API(urCommandBufferRetain)
160+
_UR_API(urCommandBufferRelease)
161+
_UR_API(urCommandBufferFinalize)
162+
_UR_API(urCommandBufferNDRangeKernel)
163+
_UR_API(urCommandBufferMemcpyUSM)
164+
_UR_API(urCommandBufferMemBufferCopy)
165+
_UR_API(urCommandBufferMemBufferCopyRect)
166+
_UR_API(urCommandBufferMemBufferWrite)
167+
_UR_API(urCommandBufferMemBufferWriteRect)
168+
_UR_API(urCommandBufferMemBufferRead)
169+
_UR_API(urCommandBufferMemBufferReadRect)
170+
_UR_API(urCommandBufferMemBufferFill)
171+
_UR_API(urCommandBufferFillUSM)
172+
_UR_API(urCommandBufferPrefetchUSM)
173+
_UR_API(urCommandBufferAdviseUSM)
174+
_UR_API(urEnqueueCommandBuffer)
175+
176+
_UR_API(urUSMPitchedAlloc)
177+
178+
// Bindless Images
179+
_UR_API(urMemUnsampledImageHandleDestroy)
180+
_UR_API(urMemSampledImageHandleDestroy)
181+
_UR_API(urBindlessImageSamplerCreate)
182+
_UR_API(urMemImageAllocate)
183+
_UR_API(urMemImageFree)
184+
_UR_API(urMemUnsampledImageCreate)
185+
_UR_API(urMemSampledImageCreate)
186+
_UR_API(urMemImageCopy)
187+
_UR_API(urMemImageGetInfo)
188+
_UR_API(urMemMipmapGetLevel)
189+
_UR_API(urMemMipmapFree)
190+
191+
// Interop
192+
_UR_API(urMemImportOpaqueFD)
193+
_UR_API(urMemReleaseInterop)
194+
_UR_API(urMemMapExternalArray)
195+
_UR_API(urImportExternalSemaphoreOpaqueFD)
196+
_UR_API(urDestroyExternalSemaphore)
197+
_UR_API(urWaitExternalSemaphore)
198+
_UR_API(urSignalExternalSemaphore)
199+
*/
200+
#undef _UR_API

sycl/source/backend.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,24 @@ backend convertBackend(pi_platform_backend PiBackend) {
6767
PI_ERROR_INVALID_OPERATION};
6868
}
6969

70+
backend convertUrBackend(ur_platform_backend_t UrBackend) {
71+
switch (UrBackend) {
72+
case UR_PLATFORM_BACKEND_LEVEL_ZERO:
73+
return backend::ext_oneapi_level_zero;
74+
case UR_PLATFORM_BACKEND_OPENCL:
75+
return backend::opencl;
76+
case UR_PLATFORM_BACKEND_CUDA:
77+
return backend::ext_oneapi_cuda;
78+
case UR_PLATFORM_BACKEND_HIP:
79+
return backend::ext_oneapi_hip;
80+
case UR_PLATFORM_BACKEND_NATIVE_CPU:
81+
return backend::ext_oneapi_native_cpu;
82+
default:
83+
// no idea what to do here
84+
return backend::all;
85+
}
86+
}
87+
7088
platform make_platform(pi_native_handle NativeHandle, backend Backend) {
7189
const auto &Plugin = getPlugin(Backend);
7290

sycl/source/detail/global_handler.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ std::vector<PluginPtr> &GlobalHandler::getPlugins() {
202202
enableOnCrashStackPrinting();
203203
return getOrCreate(MPlugins);
204204
}
205+
std::vector<UrPluginPtr> &GlobalHandler::getUrPlugins() {
206+
enableOnCrashStackPrinting();
207+
return getOrCreate(MUrPlugins);
208+
}
205209

206210
ods_target_list &
207211
GlobalHandler::getOneapiDeviceSelectorTargets(const std::string &InitValue) {
@@ -255,6 +259,14 @@ void GlobalHandler::unloadPlugins() {
255259
}
256260
// Clear after unload to avoid uses after unload.
257261
getPlugins().clear();
262+
if (MUrPlugins.Inst) {
263+
for (const auto &Plugin : getUrPlugins()) {
264+
Plugin->release();
265+
}
266+
}
267+
268+
// Clear after unload to avoid uses after unload.
269+
getUrPlugins().clear();
258270
}
259271

260272
void GlobalHandler::prepareSchedulerToRelease(bool Blocking) {

sycl/source/detail/global_handler.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@ class Scheduler;
2323
class ProgramManager;
2424
class Sync;
2525
class plugin;
26+
class urPlugin;
2627
class ods_target_list;
2728
class XPTIRegistry;
2829
class ThreadPool;
2930

3031
using PlatformImplPtr = std::shared_ptr<platform_impl>;
3132
using ContextImplPtr = std::shared_ptr<context_impl>;
3233
using PluginPtr = std::shared_ptr<plugin>;
34+
using UrPluginPtr = std::shared_ptr<urPlugin>;
3335

3436
/// Wrapper class for global data structures with non-trivial destructors.
3537
///
@@ -69,6 +71,7 @@ class GlobalHandler {
6971
std::mutex &getPlatformMapMutex();
7072
std::mutex &getFilterMutex();
7173
std::vector<PluginPtr> &getPlugins();
74+
std::vector<UrPluginPtr> &getUrPlugins();
7275
ods_target_list &getOneapiDeviceSelectorTargets(const std::string &InitValue);
7376
XPTIRegistry &getXPTIRegistry();
7477
ThreadPool &getHostTaskThreadPool();
@@ -119,6 +122,7 @@ class GlobalHandler {
119122
InstWithLock<std::mutex> MPlatformMapMutex;
120123
InstWithLock<std::mutex> MFilterMutex;
121124
InstWithLock<std::vector<PluginPtr>> MPlugins;
125+
InstWithLock<std::vector<UrPluginPtr>> MUrPlugins;
122126
InstWithLock<ods_target_list> MOneapiDeviceSelectorTargets;
123127
InstWithLock<XPTIRegistry> MXPTIRegistry;
124128
// Thread pool for host task and event callbacks execution

0 commit comments

Comments
 (0)