@@ -57,6 +57,40 @@ namespace mkl_rng_dev = oneapi::mkl::rng::device;
57
57
namespace py = pybind11;
58
58
namespace type_utils = dpctl::tensor::type_utils;
59
59
60
// Count of selectable gaussian generation methods. It is the second extent of
// gaussian_dispatch_table and the exclusive upper bound checked against method_id
// in gaussian().
+ constexpr int num_methods = 2 ; // number of methods of gaussian distribution
61
+
62
+ // static mkl_rng_dev::gaussian_method get_method(const std::int8_t method) {
63
+ // switch (method) {
64
+ // case 0: return mkl_rng_dev::gaussian_method::by_default;
65
+ // case 1: return mkl_rng_dev::gaussian_method::by_default;
66
+ // default:
67
+ // throw py::value_error();
68
+ // }
69
+ // }
70
+
71
+ template <typename DataT, typename Method>
72
+ struct GaussianDistr
73
+ {
74
+ private:
75
+ const DataT mean_;
76
+ const DataT stddev_;
77
+
78
+ public:
79
+ using method_type = Method;
80
+ using result_type = DataT;
81
+ using distr_type = typename mkl_rng_dev::gaussian<DataT, Method>;
82
+
83
+ GaussianDistr (const DataT mean, const DataT stddev)
84
+ : mean_(mean), stddev_(stddev)
85
+ {
86
+ }
87
+
88
+ inline auto operator ()(void ) const
89
+ {
90
+ return distr_type (mean_, stddev_);
91
+ }
92
+ };
93
+
60
94
typedef sycl::event (*gaussian_impl_fn_ptr_t )(sycl::queue &,
61
95
const std::uint32_t ,
62
96
const double ,
@@ -65,13 +99,12 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
65
99
char *,
66
100
const std::vector<sycl::event> &);
67
101
68
// Dispatch storage for gaussian implementations. The added declaration is a 2-D table
// indexed as [type lookup id][method id] (see its use in gaussian()); it is filled by
// init_gaussian_dispatch_table(), and gaussian() treats a nullptr entry as
// "no implementation for the requested type".
- static gaussian_impl_fn_ptr_t gaussian_dispatch_vector [dpctl_td_ns::num_types];
102
+ static gaussian_impl_fn_ptr_t gaussian_dispatch_table [dpctl_td_ns::num_types][num_methods ];
69
103
70
- // template <typename DataT, typename Method = mkl_rng_dev::gaussian_method::by_default>
71
- template <typename DataT, unsigned int vec_sz, unsigned int items_per_wi>
104
// Declared-only class template used as the kernel name type passed to
// cgh.parallel_for in gaussian_impl; the added version also carries the Method type.
+ template <typename DataT, typename Method, unsigned int vec_sz, unsigned int items_per_wi>
72
105
class gaussian_kernel ;
73
106
74
- template <typename DataT, typename Method = mkl_rng_dev::gaussian_method::by_default >
107
+ template <typename DataT, typename Method>
75
108
static sycl::event gaussian_impl (sycl::queue& exec_q,
76
109
const std::uint32_t seed,
77
110
const double mean_val,
@@ -98,20 +131,23 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
98
131
distr_event = exec_q.submit ([&](sycl::handler &cgh) {
99
132
cgh.depends_on (depends);
100
133
134
+ using GaussianDistrT = GaussianDistr<DataT, Method>;
135
+ GaussianDistrT distr (mean, stddev);
136
+
101
137
if (is_aligned<required_alignment>(out_ptr)) {
102
138
constexpr bool enable_sg_load = true ;
103
- using KernelName = gaussian_kernel<DataT, vec_sz, items_per_wi>;
139
+ using KernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
104
140
105
141
cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
106
- details::RngContigFunctor<DataT, DataT, Method , DataT, vec_sz, items_per_wi, enable_sg_load>(seed, mean, stddev , out, n));
142
+ details::RngContigFunctor<DataT, GaussianDistrT, DataT , DataT, vec_sz, items_per_wi, enable_sg_load>(seed, distr , out, n));
107
143
}
108
144
else {
109
145
constexpr bool disable_sg_load = false ;
110
- using InnerKernelName = gaussian_kernel<DataT, vec_sz, items_per_wi>;
146
+ using InnerKernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
111
147
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
112
148
113
149
cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
114
- details::RngContigFunctor<DataT, DataT, Method , DataT, vec_sz, items_per_wi, disable_sg_load>(seed, mean, stddev , out, n));
150
+ details::RngContigFunctor<DataT, GaussianDistrT, DataT , DataT, vec_sz, items_per_wi, disable_sg_load>(seed, distr , out, n));
115
151
}
116
152
});
117
153
} catch (oneapi::mkl::exception const &e) {
@@ -129,6 +165,7 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
129
165
}
130
166
131
167
std::pair<sycl::event, sycl::event> gaussian (sycl::queue exec_q,
168
+ const std::uint8_t method_id,
132
169
const std::uint32_t seed,
133
170
const double mean,
134
171
const double stddev,
@@ -166,10 +203,14 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
166
203
throw std::runtime_error (" Only population of contiguous array is supported." );
167
204
}
168
205
206
+ if (method_id >= num_methods) {
207
+ throw std::runtime_error (" Unknown method=" + std::to_string (method_id) + " for gaussian distribution." );
208
+ }
209
+
169
210
auto array_types = dpctl_td_ns::usm_ndarray_types ();
170
211
int res_type_id = array_types.typenum_to_lookup_id (res.get_typenum ());
171
212
172
- auto gaussian_fn = gaussian_dispatch_vector [res_type_id];
213
+ auto gaussian_fn = gaussian_dispatch_table [res_type_id][method_id ];
173
214
if (gaussian_fn == nullptr ) {
174
215
throw py::value_error (" No gaussian implementation defined for a required type" );
175
216
}
@@ -181,36 +222,95 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
181
222
return std::make_pair (ht_ev, gaussian_ev);
182
223
}
183
224
184
- template <typename T>
225
+ template <typename funcPtrT,
226
+ template <typename fnT, typename D, typename S> typename factory,
227
+ int _num_types,
228
+ int _num_methods>
229
+ // class DispatchTableBuilder : public dpctl_td_ns::DispatchTableBuilder<funcPtrT, factory, _num_types>
230
+ class DispatchTableBuilder /* : public dpctl_td_ns::DispatchTableBuilder<funcPtrT, factory, _num_types>*/
231
+ {
232
+ private:
233
+ template <typename dstTy>
234
+ const std::vector<funcPtrT> row_per_method () const
235
+ {
236
+ std::vector<funcPtrT> per_method = {
237
+ factory<funcPtrT, dstTy, mkl_rng_dev::gaussian_method::by_default>{}.get (),
238
+ factory<funcPtrT, dstTy, mkl_rng_dev::gaussian_method::box_muller2>{}.get (),
239
+ };
240
+ assert (per_method.size () == _num_methods);
241
+ return per_method;
242
+ }
243
+
244
+ public:
245
+ DispatchTableBuilder () = default ;
246
+ ~DispatchTableBuilder () = default ;
247
+
248
+ void populate (funcPtrT table[][_num_methods]) const
249
+ {
250
+ const auto map_by_dst_type = {row_per_method<bool >(),
251
+ row_per_method<int8_t >(),
252
+ row_per_method<uint8_t >(),
253
+ row_per_method<int16_t >(),
254
+ row_per_method<uint16_t >(),
255
+ row_per_method<int32_t >(),
256
+ row_per_method<uint32_t >(),
257
+ row_per_method<int64_t >(),
258
+ row_per_method<uint64_t >(),
259
+ row_per_method<sycl::half>(),
260
+ row_per_method<float >(),
261
+ row_per_method<double >(),
262
+ row_per_method<std::complex<float >>(),
263
+ row_per_method<std::complex<double >>()};
264
+ assert (map_by_dst_type.size () == _num_types);
265
+ int dst_id = 0 ;
266
+ for (auto &row : map_by_dst_type) {
267
+ int src_id = 0 ;
268
+ for (auto &fn_ptr : row) {
269
+ table[dst_id][src_id] = fn_ptr;
270
+ ++src_id;
271
+ }
272
+ ++dst_id;
273
+ }
274
+ }
275
+ };
276
+
277
/**
 * @brief Entry for use inside std::disjunction when matching a
 *        (data type, method) pair.
 *
 * The inherited `value` is true only when Ty is the same type as ArgTy and
 * Method is the same type as argMethod. `is_defined` is unconditionally true,
 * so a disjunction that selects this entry reports the pair as defined.
 *
 * @tparam Ty        Data type being queried.
 * @tparam ArgTy     Data type this entry stands for.
 * @tparam Method    Method type being queried.
 * @tparam argMethod Method type this entry stands for.
 */
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
struct TypePairDefinedEntry
    : std::integral_constant<bool,
                             std::conjunction_v<std::is_same<Ty, ArgTy>,
                                                std::is_same<Method, argMethod>>>
{
    static constexpr bool is_defined = true;
};
283
+
284
// Compile-time predicate telling whether a gaussian implementation exists for data
// type T combined with method type M. std::disjunction selects the first listed entry
// whose ::value is true and is_defined is read from that entry. The added entries
// cover T in {double, float} paired with M in {by_default, box_muller2}.
+ template <typename T, typename M>
185
285
struct GaussianTypePairSupportFactory
186
286
{
187
287
static constexpr bool is_defined = std::disjunction<
188
- dpctl_td_ns::TypePairDefinedEntry<T, double , T, double >,
189
- dpctl_td_ns::TypePairDefinedEntry<T, float , T, float >,
288
+ TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::by_default>,
289
+ TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::box_muller2>,
290
+ TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::by_default>,
291
+ TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::box_muller2>,
190
292
// fall-through
// NOTE(review): NotDefinedEntry is declared elsewhere; presumably it always matches
// and has is_defined == false, making every other (T, M) pair unsupported - confirm.
191
293
dpctl_td_ns::NotDefinedEntry>::is_defined;
192
294
};
193
295
194
// Factory producing one dispatch-table cell: get() returns the gaussian_impl
// instantiation for the given template arguments when
// GaussianTypePairSupportFactory reports the combination as defined, and nullptr
// otherwise. The added lines extend the factory with the method type M.
- template <typename fnT, typename T>
296
+ template <typename fnT, typename T, typename M >
195
297
struct GaussianContigFactory
196
298
{
197
299
fnT get ()
198
300
{
199
- if constexpr (GaussianTypePairSupportFactory<T>::is_defined) {
200
- return gaussian_impl<T>;
301
// if constexpr: gaussian_impl<T, M> is only instantiated for supported combinations.
+ if constexpr (GaussianTypePairSupportFactory<T, M >::is_defined) {
302
+ return gaussian_impl<T, M >;
201
303
}
202
304
else {
203
305
// nullptr marks the cell as unsupported; gaussian() checks for it before calling.
return nullptr ;
204
306
}
205
307
}
206
308
};
207
309
208
// Fills the file-level gaussian dispatch storage. The added version builds a
// DispatchTableBuilder over GaussianContigFactory with extents
// dpctl_td_ns::num_types x num_methods and populates gaussian_dispatch_table.
// NOTE(review): presumably called once during module initialization before
// gaussian() is used - the call site is not visible here; confirm.
- void init_gaussian_dispatch_vector (void )
310
+ void init_gaussian_dispatch_table (void )
209
311
{
210
- dpctl_td_ns::DispatchVectorBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory,
211
- dpctl_td_ns::num_types>
212
- contig;
213
- contig.populate_dispatch_vector (gaussian_dispatch_vector);
312
+ DispatchTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, dpctl_td_ns::num_types, num_methods> contig;
313
+ contig.populate (gaussian_dispatch_table);
214
314
}
215
315
} // namespace device
216
316
} // namespace rng
0 commit comments