33
33
// dpctl tensor headers
34
34
#include " kernels/alignment.hpp"
35
35
36
- using dpctl::tensor::kernels::alignment_utils::disabled_sg_loadstore_wrapper_krn;
37
- using dpctl::tensor::kernels::alignment_utils::is_aligned;
38
- using dpctl::tensor::kernels::alignment_utils::required_alignment;
39
-
40
36
#include " common_impl.hpp"
41
37
#include " gaussian.hpp"
42
38
43
39
#include " engine/engine_base.hpp"
44
40
#include " engine/engine_builder.hpp"
45
41
46
- // #include "dpnp_utils.hpp"
42
+ #include " dispatch/matrix.hpp"
43
+ #include " dispatch/table_builder.hpp"
44
+
47
45
48
46
namespace dpnp
49
47
{
@@ -55,26 +53,31 @@ namespace rng
55
53
{
56
54
namespace device
57
55
{
56
+ namespace dpctl_krn_ns = dpctl::tensor::kernels::alignment_utils;
58
57
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
59
58
namespace mkl_rng_dev = oneapi::mkl::rng::device;
60
59
namespace py = pybind11;
61
60
namespace type_utils = dpctl::tensor::type_utils;
62
61
62
+ using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
63
+ using dpctl_krn_ns::is_aligned;
64
+ using dpctl_krn_ns::required_alignment;
65
+
63
66
constexpr int no_of_methods = 2 ; // number of methods of gaussian distribution
64
67
65
68
template <typename DataT, typename Method>
66
- struct GaussianDistr
69
+ struct DistributorBuilder
67
70
{
68
71
private:
69
72
const DataT mean_;
70
73
const DataT stddev_;
71
74
72
75
public:
73
- using method_type = Method;
74
76
using result_type = DataT;
77
+ using method_type = Method;
75
78
using distr_type = typename mkl_rng_dev::gaussian<DataT, Method>;
76
79
77
- GaussianDistr (const DataT mean, const DataT stddev)
80
+ DistributorBuilder (const DataT mean, const DataT stddev)
78
81
: mean_(mean), stddev_(stddev)
79
82
{
80
83
}
@@ -128,23 +131,23 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
128
131
EngineBuilderT eng_builder (engine);
129
132
eng_builder.print (); // TODO: remove
130
133
131
- using GaussianDistrT = GaussianDistr <DataT, Method>;
132
- GaussianDistrT distr (mean, stddev);
134
+ using DistributorBuilderT = DistributorBuilder <DataT, Method>;
135
+ DistributorBuilderT dist_builder (mean, stddev);
133
136
134
137
if (is_aligned<required_alignment>(out_ptr)) {
135
138
constexpr bool enable_sg_load = true ;
136
139
using KernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
137
140
138
141
cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
139
- details::RngContigFunctor<EngineBuilderT, DataT, GaussianDistrT, items_per_wi, enable_sg_load>(eng_builder, distr , out, n));
142
+ details::RngContigFunctor<EngineBuilderT, DistributorBuilderT, items_per_wi, enable_sg_load>(eng_builder, dist_builder , out, n));
140
143
}
141
144
else {
142
145
constexpr bool disable_sg_load = false ;
143
146
using InnerKernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
144
147
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
145
148
146
149
cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
147
- details::RngContigFunctor<EngineBuilderT, DataT, GaussianDistrT, items_per_wi, disable_sg_load>(eng_builder, distr , out, n));
150
+ details::RngContigFunctor<EngineBuilderT, DistributorBuilderT, items_per_wi, disable_sg_load>(eng_builder, dist_builder , out, n));
148
151
}
149
152
});
150
153
} catch (oneapi::mkl::exception const &e) {
@@ -225,97 +228,12 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
225
228
return std::make_pair (ht_ev, gaussian_ev);
226
229
}
227
230
228
- template <typename funcPtrT,
229
- template <typename fnT, typename E, typename T, typename M> typename factory,
230
- int _no_of_engines,
231
- int _no_of_types,
232
- int _no_of_methods>
233
- class Dispatch3DTableBuilder
234
- {
235
- private:
236
- template <typename E, typename T>
237
- const std::vector<funcPtrT> row_per_method () const
238
- {
239
- std::vector<funcPtrT> per_method = {
240
- factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}.get (),
241
- factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}.get (),
242
- };
243
- assert (per_method.size () == _no_of_methods);
244
- return per_method;
245
- }
246
-
247
- template <typename E>
248
- auto table_per_type_and_method () const
249
- {
250
- std::vector<std::vector<funcPtrT>>
251
- table_by_type = {row_per_method<E, bool >(),
252
- row_per_method<E, int8_t >(),
253
- row_per_method<E, uint8_t >(),
254
- row_per_method<E, int16_t >(),
255
- row_per_method<E, uint16_t >(),
256
- row_per_method<E, int32_t >(),
257
- row_per_method<E, uint32_t >(),
258
- row_per_method<E, int64_t >(),
259
- row_per_method<E, uint64_t >(),
260
- row_per_method<E, sycl::half>(),
261
- row_per_method<E, float >(),
262
- row_per_method<E, double >(),
263
- row_per_method<E, std::complex<float >>(),
264
- row_per_method<E, std::complex<double >>()};
265
- assert (table_by_type.size () == _no_of_types);
266
- return table_by_type;
267
- }
268
-
269
- public:
270
- Dispatch3DTableBuilder () = default ;
271
- ~Dispatch3DTableBuilder () = default ;
272
-
273
- void populate (funcPtrT table[][_no_of_types][_no_of_methods]) const
274
- {
275
- const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<8 >>()};
276
- assert (map_by_engine.size () == _no_of_engines);
277
-
278
- std::uint16_t engine_id = 0 ;
279
- for (auto &table_by_type : map_by_engine) {
280
- std::uint16_t type_id = 0 ;
281
- for (auto &row_by_method : table_by_type) {
282
- std::uint16_t method_id = 0 ;
283
- for (auto &fn_ptr : row_by_method) {
284
- table[engine_id][type_id][method_id] = fn_ptr;
285
- ++method_id;
286
- }
287
- ++type_id;
288
- }
289
- ++engine_id;
290
- }
291
- }
292
- };
293
-
294
- template <typename Ty, typename ArgTy, typename Method, typename argMethod>
295
- struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty, ArgTy> &&
296
- std::is_same_v<Method, argMethod>>
297
- {
298
- static constexpr bool is_defined = true ;
299
- };
300
-
301
- template <typename T, typename M>
302
- struct GaussianTypePairSupportFactory
303
- {
304
- static constexpr bool is_defined = std::disjunction<
305
- TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::by_default>,
306
- TypePairDefinedEntry<T, double , M, mkl_rng_dev::gaussian_method::box_muller2>,
307
- TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::by_default>,
308
- TypePairDefinedEntry<T, float , M, mkl_rng_dev::gaussian_method::box_muller2>,
309
- // fall-through
310
- dpctl_td_ns::NotDefinedEntry>::is_defined;
311
- };
312
-
313
231
template <typename fnT, typename E, typename T, typename M>
314
232
struct GaussianContigFactory
315
233
{
316
234
fnT get ()
317
235
{
318
- if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
236
+ if constexpr (dispatch:: GaussianTypePairSupportFactory<T, M>::is_defined) {
319
237
return gaussian_impl<E, T, M>;
320
238
}
321
239
else {
@@ -326,7 +244,7 @@ struct GaussianContigFactory
326
244
327
245
void init_gaussian_dispatch_table (void )
328
246
{
329
- Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
247
+ dispatch:: Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t , GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
330
248
contig.populate (gaussian_dispatch_table);
331
249
}
332
250
} // namespace device
0 commit comments