@@ -22,19 +22,29 @@ struct CUDAContext;
22
22
struct DeviceSym {
23
23
static auto constexpr CPU () { return " cpu" ; }
24
24
static auto constexpr CUDA () { return " cuda" ; }
25
+ static auto constexpr SYCL_default () { return " sycl" ; }
26
+ static auto constexpr SYCL_CPU () { return " sycl:cpu" ; }
27
+ static auto constexpr SYCL_GPU () { return " sycl:gpu" ; }
25
28
};
26
29
27
30
/* *
28
31
* @brief A type for device ordinal. The type is packed into 32-bit for efficient use in
29
32
* viewing types like `linalg::TensorView`.
30
33
*/
34
+ constexpr static bst_d_ordinal_t kDefaultOrdinal = -1 ;
31
35
struct DeviceOrd {
32
- enum Type : std::int16_t { kCPU = 0 , kCUDA = 1 } device{kCPU };
33
- // CUDA device ordinal.
34
- bst_d_ordinal_t ordinal{- 1 };
36
+ enum Type : std::int16_t { kCPU = 0 , kCUDA = 1 , kSyclDefault = 2 , kSyclCPU = 3 , kSyclGPU = 4 } device{kCPU };
37
+ // CUDA or Sycl device ordinal.
38
+ bst_d_ordinal_t ordinal{kDefaultOrdinal };
35
39
36
40
[[nodiscard]] bool IsCUDA () const { return device == kCUDA ; }
37
41
[[nodiscard]] bool IsCPU () const { return device == kCPU ; }
42
+ [[nodiscard]] bool IsSyclDefault () const { return device == kSyclDefault ; }
43
+ [[nodiscard]] bool IsSyclCPU () const { return device == kSyclCPU ; }
44
+ [[nodiscard]] bool IsSyclGPU () const { return device == kSyclGPU ; }
45
+ [[nodiscard]] bool IsSycl () const { return (IsSyclDefault () ||
46
+ IsSyclCPU () ||
47
+ IsSyclGPU ()); }
38
48
39
49
DeviceOrd () = default ;
40
50
constexpr DeviceOrd (Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
@@ -47,14 +57,35 @@ struct DeviceOrd {
47
57
/* *
48
58
* @brief Constructor for CPU.
49
59
*/
50
- [[nodiscard]] constexpr static auto CPU () { return DeviceOrd{kCPU , - 1 }; }
60
+ [[nodiscard]] constexpr static auto CPU () { return DeviceOrd{kCPU , kDefaultOrdinal }; }
51
61
/* *
52
62
* @brief Constructor for CUDA device.
53
63
*
54
64
* @param ordinal CUDA device ordinal.
55
65
*/
56
66
[[nodiscard]] static auto CUDA (bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA , ordinal}; }
57
67
68
+ /* *
69
+ * @brief Constructor for SYCL.
70
+ *
71
+ * @param ordinal SYCL device ordinal.
72
+ */
73
+ [[nodiscard]] constexpr static auto SYCL_default (bst_d_ordinal_t ordinal = kDefaultOrdinal ) { return DeviceOrd{kSyclDefault , ordinal}; }
74
+
75
+ /* *
76
+ * @brief Constructor for SYCL CPU.
77
+ *
78
+ * @param ordinal SYCL CPU device ordinal.
79
+ */
80
+ [[nodiscard]] constexpr static auto SYCL_CPU (bst_d_ordinal_t ordinal = kDefaultOrdinal ) { return DeviceOrd{kSyclCPU , ordinal}; }
81
+
82
+ /* *
83
+ * @brief Constructor for SYCL GPU.
84
+ *
85
+ * @param ordinal SYCL GPU device ordinal.
86
+ */
87
+ [[nodiscard]] constexpr static auto SYCL_GPU (bst_d_ordinal_t ordinal = kDefaultOrdinal ) { return DeviceOrd{kSyclGPU , ordinal}; }
88
+
58
89
[[nodiscard]] bool operator ==(DeviceOrd const & that) const {
59
90
return device == that.device && ordinal == that.ordinal ;
60
91
}
@@ -68,6 +99,12 @@ struct DeviceOrd {
68
99
return DeviceSym::CPU ();
69
100
case DeviceOrd::kCUDA :
70
101
return DeviceSym::CUDA () + (' :' + std::to_string (ordinal));
102
+ case DeviceOrd::kSyclDefault :
103
+ return DeviceSym::SYCL_default () + (' :' + std::to_string (ordinal));
104
+ case DeviceOrd::kSyclCPU :
105
+ return DeviceSym::SYCL_CPU () + (' :' + std::to_string (ordinal));
106
+ case DeviceOrd::kSyclGPU :
107
+ return DeviceSym::SYCL_GPU () + (' :' + std::to_string (ordinal));
71
108
default : {
72
109
LOG (FATAL) << " Unknown device." ;
73
110
return " " ;
@@ -135,6 +172,25 @@ struct Context : public XGBoostParameter<Context> {
135
172
* @brief Is XGBoost running on a CUDA device?
136
173
*/
137
174
[[nodiscard]] bool IsCUDA () const { return Device ().IsCUDA (); }
175
+ /* *
176
+ * @brief Is XGBoost running on the default SYCL device?
177
+ */
178
+ [[nodiscard]] bool IsSyclDefault () const { return Device ().IsSyclDefault (); }
179
+ /* *
180
+ * @brief Is XGBoost running on a SYCL CPU?
181
+ */
182
+ [[nodiscard]] bool IsSyclCPU () const { return Device ().IsSyclCPU (); }
183
+ /* *
184
+ * @brief Is XGBoost running on a SYCL GPU?
185
+ */
186
+ [[nodiscard]] bool IsSyclGPU () const { return Device ().IsSyclGPU (); }
187
+ /* *
188
+ * @brief Is XGBoost running on any SYCL device?
189
+ */
190
+ [[nodiscard]] bool IsSycl () const { return IsSyclDefault ()
191
+ || IsSyclCPU ()
192
+ || IsSyclGPU (); }
193
+
138
194
/* *
139
195
* @brief Get the current device and ordinal.
140
196
*/
@@ -171,6 +227,29 @@ struct Context : public XGBoostParameter<Context> {
171
227
/* *
172
228
* @brief Call function based on the current device.
173
229
*/
230
+ template <typename CPUFn, typename CUDAFn, typename SYCLFn>
231
+ decltype (auto ) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
232
+ static_assert (std::is_same_v<std::invoke_result_t <CPUFn>, std::invoke_result_t <CUDAFn>>);
233
+ switch (this ->Device ().device ) {
234
+ case DeviceOrd::kCPU :
235
+ return cpu_fn ();
236
+ case DeviceOrd::kCUDA :
237
+ return cuda_fn ();
238
+ case DeviceOrd::kSyclDefault :
239
+ return sycl_fn ();
240
+ case DeviceOrd::kSyclCPU :
241
+ return sycl_fn ();
242
+ case DeviceOrd::kSyclGPU :
243
+ return sycl_fn ();
244
+ default :
245
+ // Do not use the device name as this is likely an internal error, the name
246
+ // wouldn't be valid.
247
+ LOG (FATAL) << " Unknown device type:"
248
+ << static_cast <std::underlying_type_t <DeviceOrd::Type>>(this ->Device ().device );
249
+ break ;
250
+ }
251
+ return std::invoke_result_t <CPUFn>();
252
+ }
174
253
template <typename CPUFn, typename CUDAFn>
175
254
decltype (auto ) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const {
176
255
static_assert (std::is_same_v<std::invoke_result_t <CPUFn>, std::invoke_result_t <CUDAFn>>);
@@ -179,6 +258,12 @@ struct Context : public XGBoostParameter<Context> {
179
258
return cpu_fn ();
180
259
case DeviceOrd::kCUDA :
181
260
return cuda_fn ();
261
+ case DeviceOrd::kSyclDefault :
262
+ LOG (FATAL) << " The requested feature is not implemented for sycl yet" ;
263
+ case DeviceOrd::kSyclCPU :
264
+ LOG (FATAL) << " The requested feature is not implemented for sycl yet" ;
265
+ case DeviceOrd::kSyclGPU :
266
+ LOG (FATAL) << " The requested feature is not implemented for sycl yet" ;
182
267
default :
183
268
// Do not use the device name as this is likely an internal error, the name
184
269
// wouldn't be valid.
@@ -213,7 +298,9 @@ struct Context : public XGBoostParameter<Context> {
213
298
void SetDeviceOrdinal (Args const & kwargs);
214
299
Context& SetDevice (DeviceOrd d) {
215
300
this ->device_ = d;
216
- this ->gpu_id = d.ordinal ; // this can be removed once we move away from `gpu_id`.
301
+ if (d.IsCUDA ()) {
302
+ this ->gpu_id = d.ordinal ; // this can be removed once we move away from `gpu_id`.
303
+ }
217
304
this ->device = d.Name ();
218
305
return *this ;
219
306
}
0 commit comments