Skip to content

Commit 23799c9

Browse files
Fixed dpctl_device_selector for open-source DPC++, make wrap/unwrap templated
Moved dpctl_*_selector definitions into dpctl::syclinterface namespace. wrap/unwrap functions are made templated, with template parameter being C++ type to pointer of which the opaque pointer is cast. dpctl_device_selector for SYCL 2020 implements virtual call operator to enable type inference of this type as std::function<int(const device&)>. The derived classes override the call operator as appropriate. Calls to constructors of sycl::device and sycl::platform which take device selector callable are implemented differently depending on the compiler version.
1 parent 91f552f commit 23799c9

17 files changed

+384
-286
lines changed

libsyclinterface/include/dpctl_sycl_device_selector_interface.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,31 +65,31 @@ DPCTL_API
6565
__dpctl_give DPCTLSyclDeviceSelectorRef DPCTLCPUSelector_Create(void);
6666

6767
/*!
68-
* @brief Returns an opaque wrapper for sycl::ONEAPI::filter_selector object
69-
* based on the passed in filter string.
68+
* @brief Returns an opaque wrapper for sycl::ext::oneapi::filter_selector
69+
* object based on the passed in filter string.
7070
*
7171
* @param filter_str A C string providing a filter based on which to
72-
* create a device_selector.
73-
* @return An opaque pointer to a sycl::ONEAPI::filter_selector object.
72+
* create a device selector.
73+
* @return An opaque pointer to a sycl::ext::oneapi::filter_selector object.
7474
* @ingroup DeviceSelectors
7575
*/
7676
DPCTL_API
7777
__dpctl_give DPCTLSyclDeviceSelectorRef
7878
DPCTLFilterSelector_Create(__dpctl_keep const char *filter_str);
7979

8080
/*!
81-
* @brief Returns an opaque wrapper for sycl::gpu_selector object.
81+
* @brief Returns an opaque wrapper for dpctl_gpu_selector object.
8282
*
83-
* @return An opaque pointer to a sycl::gpu_selector object.
83+
* @return An opaque pointer to a dpctl_gpu_selector object.
8484
* @ingroup DeviceSelectors
8585
*/
8686
DPCTL_API
8787
__dpctl_give DPCTLSyclDeviceSelectorRef DPCTLGPUSelector_Create(void);
8888

8989
/*!
90-
* @brief Returns an opaque wrapper for sycl::host_selector object.
90+
* @brief Returns an opaque wrapper for dpctl_host_selector object.
9191
*
92-
* @return An opaque pointer to a sycl::host_selector object.
92+
* @return An opaque pointer to a dpctl_host_selector object.
9393
* @ingroup DeviceSelectors
9494
*/
9595
DPCTL_API

libsyclinterface/include/dpctl_sycl_platform_interface.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ __dpctl_give DPCTLSyclPlatformRef DPCTLPlatform_Create(void);
6363

6464
/*!
6565
* @brief Creates a new DPCTLSyclPlatformRef for a SYCL platform constructed
66-
* using the device_selector wrapped by DPCTLSyclDeviceSelectorRef.
66+
* using the dpctl_device_selector wrapped by DPCTLSyclDeviceSelectorRef.
6767
*
68-
* @param DSRef An opaque pointer to a SYCL device_selector object.
68+
* @param DSRef An opaque pointer to a SYCL dpctl_device_selector
69+
* object.
6970
* @return A new DPCTLSyclPlatformRef pointer wrapping a SYCL platform object.
7071
* @ingroup PlatformInterface
7172
*/

libsyclinterface/include/dpctl_sycl_type_casters.hpp

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,32 @@
3030

3131
#include "dpctl_sycl_types.h"
3232
#include <CL/sycl.hpp>
33+
#include <iostream>
3334
#include <vector>
3435

36+
namespace dpctl::syclinterface
37+
{
38+
3539
#if __SYCL_COMPILER_VERSION >= 20221020
3640

3741
class dpctl_device_selector
3842
{
3943
public:
4044
virtual ~dpctl_device_selector() = default;
41-
42-
virtual int operator()(const sycl::device &device) const = 0;
45+
static constexpr int REJECT_DEVICE = -1;
46+
virtual int operator()(const sycl::device &d) const
47+
{
48+
std::cout << "Outright rejecting "
49+
<< d.get_info<sycl::info::device::name>() << std::endl;
50+
return REJECT_DEVICE;
51+
};
4352
};
4453

4554
class dpctl_accelerator_selector : public dpctl_device_selector
4655
{
4756
public:
4857
dpctl_accelerator_selector() = default;
49-
int operator()(const sycl::device &d) const
58+
int operator()(const sycl::device &d) const override
5059
{
5160
return sycl::accelerator_selector_v(d);
5261
}
@@ -56,17 +65,19 @@ class dpctl_default_selector : public dpctl_device_selector
5665
{
5766
public:
5867
dpctl_default_selector() = default;
59-
int operator()(const sycl::device &d) const
68+
int operator()(const sycl::device &d) const override
6069
{
61-
return sycl::default_selector_v(d);
70+
auto score = sycl::default_selector_v(d);
71+
std::cout << "Got score = " << score << std::endl;
72+
return score;
6273
}
6374
};
6475

6576
class dpctl_gpu_selector : public dpctl_device_selector
6677
{
6778
public:
6879
dpctl_gpu_selector() = default;
69-
int operator()(const sycl::device &d) const
80+
int operator()(const sycl::device &d) const override
7081
{
7182
return sycl::gpu_selector_v(d);
7283
}
@@ -76,7 +87,7 @@ class dpctl_cpu_selector : public dpctl_device_selector
7687
{
7788
public:
7889
dpctl_cpu_selector() = default;
79-
int operator()(const sycl::device &d) const
90+
int operator()(const sycl::device &d) const override
8091
{
8192
return sycl::cpu_selector_v(d);
8293
}
@@ -87,7 +98,7 @@ class dpctl_filter_selector : public dpctl_device_selector
8798
public:
8899
dpctl_filter_selector(const std::string &fs) : _impl(fs) {}
89100

90-
int operator()(const sycl::device &d) const
101+
int operator()(const sycl::device &d) const override
91102
{
92103
return _impl(d);
93104
}
@@ -100,13 +111,10 @@ class dpctl_host_selector : public dpctl_device_selector
100111
{
101112
public:
102113
dpctl_host_selector() = default;
103-
int operator()(const sycl::device &) const
114+
int operator()(const sycl::device &) const override
104115
{
105-
return REJECTED_SCORE;
116+
return REJECT_DEVICE;
106117
}
107-
108-
private:
109-
constexpr static int REJECTED_SCORE = -1;
110118
};
111119

112120
#else
@@ -201,22 +209,20 @@ class dpctl_host_selector : public dpctl_device_selector
201209
#endif
202210

203211
/*!
204-
@brief Creates two convenience functions to reinterpret_cast an opaque
205-
pointer to a pointer to a Sycl type and vice-versa.
212+
@brief Creates two convenience templated functions to
213+
reinterpret_cast an opaque pointer to a pointer to a Sycl type
214+
and vice-versa.
206215
*/
207216
#define DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ty, ref) \
208-
__attribute__((unused)) inline ty *unwrap(ref P) \
217+
template <typename T, \
218+
std::enable_if_t<std::is_same<T, ty>::value, bool> = true> \
219+
__attribute__((unused)) T *unwrap(ref P) \
209220
{ \
210221
return reinterpret_cast<ty *>(P); \
211-
} \
212-
\
213-
__attribute__((unused)) inline ref wrap(const ty *P) \
214-
{ \
215-
return reinterpret_cast<ref>(const_cast<ty *>(P)); \
216222
} \
217223
template <typename T, \
218224
std::enable_if_t<std::is_same<T, ty>::value, bool> = true> \
219-
ref wrap(const ty *P) \
225+
__attribute__((unused)) ref wrap(const ty *P) \
220226
{ \
221227
return reinterpret_cast<ref>(const_cast<ty *>(P)); \
222228
}
@@ -247,3 +253,5 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclEventRef>,
247253
DPCTLEventVectorRef)
248254

249255
#endif
256+
257+
} // namespace dpctl::syclinterface

libsyclinterface/source/dpctl_sycl_context_interface.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,27 @@
3232

3333
using namespace sycl;
3434

35+
namespace
36+
{
37+
using namespace dpctl::syclinterface;
38+
} // end of anonymous namespace
39+
3540
__dpctl_give DPCTLSyclContextRef
3641
DPCTLContext_Create(__dpctl_keep const DPCTLSyclDeviceRef DRef,
3742
error_handler_callback *handler,
3843
int /**/)
3944
{
4045
DPCTLSyclContextRef CRef = nullptr;
41-
auto Device = unwrap(DRef);
46+
auto Device = unwrap<device>(DRef);
4247
if (!Device) {
4348
error_handler("Cannot create device from DPCTLSyclDeviceRef"
4449
"as input is a nullptr.",
4550
__FILE__, __func__, __LINE__);
4651
return nullptr;
4752
}
4853
try {
49-
CRef = wrap(new context(*Device, DPCTL_AsyncErrorHandler(handler)));
54+
CRef = wrap<context>(
55+
new context(*Device, DPCTL_AsyncErrorHandler(handler)));
5056
} catch (std::exception const &e) {
5157
error_handler(e, __FILE__, __func__, __LINE__);
5258
}
@@ -61,7 +67,7 @@ DPCTLContext_CreateFromDevices(__dpctl_keep const DPCTLDeviceVectorRef DVRef,
6167
{
6268
DPCTLSyclContextRef CRef = nullptr;
6369
std::vector<device> Devices;
64-
auto DeviceRefs = unwrap(DVRef);
70+
auto DeviceRefs = unwrap<std::vector<DPCTLSyclDeviceRef>>(DVRef);
6571
if (!DeviceRefs) {
6672
error_handler("Cannot create device reference from DPCTLDeviceVectorRef"
6773
"as input is a nullptr.",
@@ -71,11 +77,12 @@ DPCTLContext_CreateFromDevices(__dpctl_keep const DPCTLDeviceVectorRef DVRef,
7177
Devices.reserve(DeviceRefs->size());
7278

7379
for (auto const &DRef : *DeviceRefs) {
74-
Devices.emplace_back(*unwrap(DRef));
80+
Devices.emplace_back(*unwrap<device>(DRef));
7581
}
7682

7783
try {
78-
CRef = wrap(new context(Devices, DPCTL_AsyncErrorHandler(handler)));
84+
CRef = wrap<context>(
85+
new context(Devices, DPCTL_AsyncErrorHandler(handler)));
7986
} catch (std::exception const &e) {
8087
error_handler(e, __FILE__, __func__, __LINE__);
8188
}
@@ -91,21 +98,21 @@ bool DPCTLContext_AreEq(__dpctl_keep const DPCTLSyclContextRef CtxRef1,
9198
__LINE__);
9299
return false;
93100
}
94-
return (*unwrap(CtxRef1) == *unwrap(CtxRef2));
101+
return (*unwrap<context>(CtxRef1) == *unwrap<context>(CtxRef2));
95102
}
96103

97104
__dpctl_give DPCTLSyclContextRef
98105
DPCTLContext_Copy(__dpctl_keep const DPCTLSyclContextRef CRef)
99106
{
100-
auto Context = unwrap(CRef);
107+
auto Context = unwrap<context>(CRef);
101108
if (!Context) {
102109
error_handler("Cannot copy DPCTLSyclContextRef as input is a nullptr.",
103110
__FILE__, __func__, __LINE__);
104111
return nullptr;
105112
}
106113
try {
107114
auto CopiedContext = new context(*Context);
108-
return wrap(CopiedContext);
115+
return wrap<context>(CopiedContext);
109116
} catch (std::exception const &e) {
110117
error_handler(e, __FILE__, __func__, __LINE__);
111118
return nullptr;
@@ -115,16 +122,17 @@ DPCTLContext_Copy(__dpctl_keep const DPCTLSyclContextRef CRef)
115122
__dpctl_give DPCTLDeviceVectorRef
116123
DPCTLContext_GetDevices(__dpctl_keep const DPCTLSyclContextRef CRef)
117124
{
118-
auto Context = unwrap(CRef);
125+
auto Context = unwrap<context>(CRef);
119126
if (!Context) {
120127
error_handler("Cannot retrieve devices from DPCTLSyclContextRef as "
121128
"input is a nullptr.",
122129
__FILE__, __func__, __LINE__);
123130
return nullptr;
124131
}
125-
std::vector<DPCTLSyclDeviceRef> *DevicesVectorPtr = nullptr;
132+
using vecTy = std::vector<DPCTLSyclDeviceRef>;
133+
vecTy *DevicesVectorPtr = nullptr;
126134
try {
127-
DevicesVectorPtr = new std::vector<DPCTLSyclDeviceRef>();
135+
DevicesVectorPtr = new vecTy();
128136
} catch (std::exception const &e) {
129137
delete DevicesVectorPtr;
130138
error_handler(e, __FILE__, __func__, __LINE__);
@@ -134,9 +142,9 @@ DPCTLContext_GetDevices(__dpctl_keep const DPCTLSyclContextRef CRef)
134142
auto Devices = Context->get_devices();
135143
DevicesVectorPtr->reserve(Devices.size());
136144
for (const auto &Dev : Devices) {
137-
DevicesVectorPtr->emplace_back(wrap(new device(Dev)));
145+
DevicesVectorPtr->emplace_back(wrap<device>(new device(Dev)));
138146
}
139-
return wrap(DevicesVectorPtr);
147+
return wrap<vecTy>(DevicesVectorPtr);
140148
} catch (std::exception const &e) {
141149
delete DevicesVectorPtr;
142150
error_handler(e, __FILE__, __func__, __LINE__);
@@ -146,7 +154,7 @@ DPCTLContext_GetDevices(__dpctl_keep const DPCTLSyclContextRef CRef)
146154

147155
size_t DPCTLContext_DeviceCount(__dpctl_keep const DPCTLSyclContextRef CRef)
148156
{
149-
auto Context = unwrap(CRef);
157+
auto Context = unwrap<context>(CRef);
150158
if (!Context) {
151159
error_handler("Cannot retrieve devices from DPCTLSyclContextRef as "
152160
"input is a nullptr.",
@@ -159,7 +167,7 @@ size_t DPCTLContext_DeviceCount(__dpctl_keep const DPCTLSyclContextRef CRef)
159167

160168
bool DPCTLContext_IsHost(__dpctl_keep const DPCTLSyclContextRef CtxRef)
161169
{
162-
auto Ctx = unwrap(CtxRef);
170+
auto Ctx = unwrap<context>(CtxRef);
163171
if (Ctx) {
164172
return Ctx->is_host();
165173
}
@@ -168,7 +176,7 @@ bool DPCTLContext_IsHost(__dpctl_keep const DPCTLSyclContextRef CtxRef)
168176

169177
void DPCTLContext_Delete(__dpctl_take DPCTLSyclContextRef CtxRef)
170178
{
171-
delete unwrap(CtxRef);
179+
delete unwrap<context>(CtxRef);
172180
}
173181

174182
DPCTLSyclBackendType
@@ -178,7 +186,7 @@ DPCTLContext_GetBackend(__dpctl_keep const DPCTLSyclContextRef CtxRef)
178186
return DPCTL_UNKNOWN_BACKEND;
179187
}
180188

181-
auto BE = unwrap(CtxRef)->get_platform().get_backend();
189+
auto BE = unwrap<context>(CtxRef)->get_platform().get_backend();
182190

183191
switch (BE) {
184192
case backend::host:
@@ -197,7 +205,7 @@ DPCTLContext_GetBackend(__dpctl_keep const DPCTLSyclContextRef CtxRef)
197205
size_t DPCTLContext_Hash(__dpctl_keep const DPCTLSyclContextRef CtxRef)
198206
{
199207
if (CtxRef) {
200-
auto C = unwrap(CtxRef);
208+
auto C = unwrap<context>(CtxRef);
201209
std::hash<context> hash_fn;
202210
return hash_fn(*C);
203211
}

0 commit comments

Comments
 (0)