Skip to content

Commit 84074fd

Browse files
committed
Added distribution method dispatching
1 parent 114ab1d commit 84074fd

File tree

4 files changed

+138
-34
lines changed

4 files changed

+138
-34
lines changed

dpnp/backend/extensions/rng/device/common_impl.hpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ namespace mkl_rng_dev = oneapi::mkl::rng::device;
5151

5252
/*! @brief Functor for unary function evaluation on contiguous array */
5353
template <typename DataT,
54+
typename GaussianDistrT,
5455
typename ResT = DataT,
55-
typename Method = mkl_rng_dev::gaussian_method::by_default,
5656
typename UnaryOperatorT = ResT,
5757
unsigned int vec_sz = 8,
5858
unsigned int items_per_wi = 4,
@@ -61,14 +61,14 @@ struct RngContigFunctor
6161
{
6262
private:
6363
const std::uint32_t seed_;
64-
const DataT mean_;
65-
const DataT stddev_;
64+
GaussianDistrT distr_;
6665
ResT * const res_ = nullptr;
6766
const size_t nelems_;
6867

6968
public:
70-
RngContigFunctor(const std::uint32_t seed, const DataT mean, const DataT stddev, ResT *res, const size_t n_elems)
71-
: seed_(seed), mean_(mean), stddev_(stddev), res_(res), nelems_(n_elems)
69+
70+
RngContigFunctor(const std::uint32_t seed, GaussianDistrT& distr, ResT *res, const size_t n_elems)
71+
: seed_(seed), distr_(distr), res_(res), nelems_(n_elems)
7272
{
7373
}
7474

@@ -80,8 +80,11 @@ struct RngContigFunctor
8080
const std::uint8_t sg_size = sg.get_local_range()[0];
8181
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
8282

83-
auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id); // offset is questionable...
84-
mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
83+
using EngineT = typename mkl_rng_dev::mrg32k3a<vec_sz>;
84+
auto engine = EngineT(seed_, nelems_ * global_id); // offset is questionable...
85+
86+
using DistrT = typename GaussianDistrT::distr_type;
87+
DistrT distr = distr_();
8588

8689
if constexpr (enable_sg_load) {
8790
const size_t base = items_per_wi * vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
@@ -92,13 +95,13 @@ struct RngContigFunctor
9295
size_t offset = base + static_cast<size_t>(it) * static_cast<size_t>(sg_size);
9396
auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
9497

95-
sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate(distr, engine);
98+
sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate<DistrT, EngineT>(distr, engine);
9699
sg.store<vec_sz>(out_multi_ptr, rng_val_vec);
97100
}
98101
}
99102
else {
100103
for (size_t offset = base + sg.get_local_id()[0]; offset < nelems_; offset += sg_size) {
101-
res_[offset] = mkl_rng_dev::generate_single(distr, engine);
104+
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
102105
}
103106
}
104107
}
@@ -108,7 +111,7 @@ struct RngContigFunctor
108111
base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
109112
for (size_t offset = base; offset < std::min(nelems_, base + sg_size * (items_per_wi * vec_sz)); offset += sg_size)
110113
{
111-
res_[offset] = mkl_rng_dev::generate_single(distr, engine);
114+
res_[offset] = mkl_rng_dev::generate_single<DistrT, EngineT>(distr, engine);
112115
}
113116
}
114117
}

dpnp/backend/extensions/rng/device/gaussian.cpp

Lines changed: 120 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,40 @@ namespace mkl_rng_dev = oneapi::mkl::rng::device;
5757
namespace py = pybind11;
5858
namespace type_utils = dpctl::tensor::type_utils;
5959

60+
constexpr int num_methods = 2; // number of methods of gaussian distribution
61+
62+
// static mkl_rng_dev::gaussian_method get_method(const std::int8_t method) {
63+
// switch (method) {
64+
// case 0: return mkl_rng_dev::gaussian_method::by_default;
65+
// case 1: return mkl_rng_dev::gaussian_method::by_default;
66+
// default:
67+
// throw py::value_error();
68+
// }
69+
// }
70+
71+
/*! @brief Holder of gaussian distribution parameters that builds the device distribution on call.
 *
 * Stores the mean and the standard deviation and, when invoked, returns a
 * freshly constructed `mkl_rng_dev::gaussian<DataT, Method>` initialized with them.
 *
 * @tparam DataT  Type of the mean/stddev values; also exposed as `result_type`.
 * @tparam Method Generation method tag forwarded to `mkl_rng_dev::gaussian`;
 *                exposed as `method_type`.
 */
template <typename DataT, typename Method>
72+
struct GaussianDistr
73+
{
74+
private:
75+
const DataT mean_;
76+
const DataT stddev_;
77+
78+
public:
79+
using method_type = Method;
80+
using result_type = DataT;
81+
// Concrete distribution type produced by operator().
82+
using distr_type = typename mkl_rng_dev::gaussian<DataT, Method>;
83+
84+
// Captures the distribution parameters; no distribution object is created here.
GaussianDistr(const DataT mean, const DataT stddev)
85+
: mean_(mean), stddev_(stddev)
86+
{
87+
}
88+
89+
// Builds and returns a new `distr_type` from the stored mean and stddev.
inline auto operator()(void) const
90+
{
91+
return distr_type(mean_, stddev_);
92+
}
93+
};
93+
6094
typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
6195
const std::uint32_t,
6296
const double,
@@ -65,13 +99,12 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
6599
char *,
66100
const std::vector<sycl::event> &);
67101

68-
static gaussian_impl_fn_ptr_t gaussian_dispatch_vector[dpctl_td_ns::num_types];
102+
static gaussian_impl_fn_ptr_t gaussian_dispatch_table[dpctl_td_ns::num_types][num_methods];
69103

70-
// template <typename DataT, typename Method = mkl_rng_dev::gaussian_method::by_default>
71-
template <typename DataT, unsigned int vec_sz, unsigned int items_per_wi>
104+
template <typename DataT, typename Method, unsigned int vec_sz, unsigned int items_per_wi>
72105
class gaussian_kernel;
73106

74-
template <typename DataT, typename Method = mkl_rng_dev::gaussian_method::by_default>
107+
template <typename DataT, typename Method>
75108
static sycl::event gaussian_impl(sycl::queue& exec_q,
76109
const std::uint32_t seed,
77110
const double mean_val,
@@ -98,20 +131,23 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
98131
distr_event = exec_q.submit([&](sycl::handler &cgh) {
99132
cgh.depends_on(depends);
100133

134+
using GaussianDistrT = GaussianDistr<DataT, Method>;
135+
GaussianDistrT distr(mean, stddev);
136+
101137
if (is_aligned<required_alignment>(out_ptr)) {
102138
constexpr bool enable_sg_load = true;
103-
using KernelName = gaussian_kernel<DataT, vec_sz, items_per_wi>;
139+
using KernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
104140

105141
cgh.parallel_for<KernelName>(sycl::nd_range<1>({global_size}, {local_size}),
106-
details::RngContigFunctor<DataT, DataT, Method, DataT, vec_sz, items_per_wi, enable_sg_load>(seed, mean, stddev, out, n));
142+
details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, enable_sg_load>(seed, distr, out, n));
107143
}
108144
else {
109145
constexpr bool disable_sg_load = false;
110-
using InnerKernelName = gaussian_kernel<DataT, vec_sz, items_per_wi>;
146+
using InnerKernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
111147
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
112148

113149
cgh.parallel_for<KernelName>(sycl::nd_range<1>({global_size}, {local_size}),
114-
details::RngContigFunctor<DataT, DataT, Method, DataT, vec_sz, items_per_wi, disable_sg_load>(seed, mean, stddev, out, n));
150+
details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, disable_sg_load>(seed, distr, out, n));
115151
}
116152
});
117153
} catch (oneapi::mkl::exception const &e) {
@@ -129,6 +165,7 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
129165
}
130166

131167
std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
168+
const std::uint8_t method_id,
132169
const std::uint32_t seed,
133170
const double mean,
134171
const double stddev,
@@ -166,10 +203,14 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
166203
throw std::runtime_error("Only population of contiguous array is supported.");
167204
}
168205

206+
if (method_id >= num_methods) {
207+
throw std::runtime_error("Unknown method=" + std::to_string(method_id) + " for gaussian distribution.");
208+
}
209+
169210
auto array_types = dpctl_td_ns::usm_ndarray_types();
170211
int res_type_id = array_types.typenum_to_lookup_id(res.get_typenum());
171212

172-
auto gaussian_fn = gaussian_dispatch_vector[res_type_id];
213+
auto gaussian_fn = gaussian_dispatch_table[res_type_id][method_id];
173214
if (gaussian_fn == nullptr) {
174215
throw py::value_error("No gaussian implementation defined for a required type");
175216
}
@@ -181,36 +222,95 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
181222
return std::make_pair(ht_ev, gaussian_ev);
182223
}
183224

184-
template <typename T>
225+
/*! @brief Fills a two-dimensional dispatch table indexed by [data type][gaussian method].
 *
 * Each table entry is obtained from `factory<funcPtrT, dstTy, Method>{}.get()`.
 *
 * @tparam funcPtrT     Type of the function pointers stored in the table.
 * @tparam factory      Template taking (function pointer type, data type, method)
 *                      whose `get()` returns the entry for that combination.
 * @tparam _num_types   Expected number of rows; checked with `assert` in `populate`.
 * @tparam _num_methods Number of columns; checked with `assert` in `row_per_method`.
 */
template <typename funcPtrT,
226+
template <typename fnT, typename D, typename S> typename factory,
227+
int _num_types,
228+
int _num_methods>
229+
// class DispatchTableBuilder : public dpctl_td_ns::DispatchTableBuilder<funcPtrT, factory, _num_types>
230+
class DispatchTableBuilder/* : public dpctl_td_ns::DispatchTableBuilder<funcPtrT, factory, _num_types>*/
231+
{
232+
private:
233+
// Builds one table row for `dstTy`: the factory result for `by_default`
// followed by the factory result for `box_muller2`.
template <typename dstTy>
234+
const std::vector<funcPtrT> row_per_method() const
235+
{
236+
std::vector<funcPtrT> per_method = {
237+
factory<funcPtrT, dstTy, mkl_rng_dev::gaussian_method::by_default>{}.get(),
238+
factory<funcPtrT, dstTy, mkl_rng_dev::gaussian_method::box_muller2>{}.get(),
239+
};
240+
// The number of methods listed above must agree with the table width.
assert(per_method.size() == _num_methods);
241+
return per_method;
242+
}
243+
244+
public:
245+
DispatchTableBuilder() = default;
246+
~DispatchTableBuilder() = default;
247+
248+
// Writes one row per data type into `table`, in the order the types are
// listed below, and one column per method within each row.
// NOTE(review): the row order presumably has to match the lookup ids returned
// by `typenum_to_lookup_id` used when the table is read - confirm.
void populate(funcPtrT table[][_num_methods]) const
249+
{
250+
const auto map_by_dst_type = {row_per_method<bool>(),
251+
row_per_method<int8_t>(),
252+
row_per_method<uint8_t>(),
253+
row_per_method<int16_t>(),
254+
row_per_method<uint16_t>(),
255+
row_per_method<int32_t>(),
256+
row_per_method<uint32_t>(),
257+
row_per_method<int64_t>(),
258+
row_per_method<uint64_t>(),
259+
row_per_method<sycl::half>(),
260+
row_per_method<float>(),
261+
row_per_method<double>(),
262+
row_per_method<std::complex<float>>(),
263+
row_per_method<std::complex<double>>()};
264+
// The number of rows listed above must agree with the table height.
assert(map_by_dst_type.size() == _num_types);
265+
int dst_id = 0;
266+
for (auto &row : map_by_dst_type) {
267+
// Despite its name, `src_id` is the column index, i.e. the method index within the row.
int src_id = 0;
268+
for (auto &fn_ptr : row) {
269+
table[dst_id][src_id] = fn_ptr;
270+
++src_id;
271+
}
272+
++dst_id;
273+
}
274+
}
275+
};
276+
277+
/*! @brief Compile-time entry matching a (data type, method) pair.
 *
 * The inherited `value` is true only when `Ty` is the same type as `ArgTy`
 * and `Method` is the same type as `argMethod`. `is_defined` is always true,
 * so an entry selected through its `value` reports the pair as supported.
 */
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
278+
struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty, ArgTy> &&
279+
std::is_same_v<Method, argMethod>>
280+
{
281+
// Unconditionally true; whether the entry matches is carried by the base class `value`.
static constexpr bool is_defined = true;
282+
};
283+
284+
template <typename T, typename M>
185285
struct GaussianTypePairSupportFactory
186286
{
187287
static constexpr bool is_defined = std::disjunction<
188-
dpctl_td_ns::TypePairDefinedEntry<T, double, T, double>,
189-
dpctl_td_ns::TypePairDefinedEntry<T, float, T, float>,
288+
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::by_default>,
289+
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::box_muller2>,
290+
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::by_default>,
291+
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::box_muller2>,
190292
// fall-through
191293
dpctl_td_ns::NotDefinedEntry>::is_defined;
192294
};
193295

194-
template <typename fnT, typename T>
296+
template <typename fnT, typename T, typename M>
195297
struct GaussianContigFactory
196298
{
197299
fnT get()
198300
{
199-
if constexpr (GaussianTypePairSupportFactory<T>::is_defined) {
200-
return gaussian_impl<T>;
301+
if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
302+
return gaussian_impl<T, M>;
201303
}
202304
else {
203305
return nullptr;
204306
}
205307
}
206308
};
207309

208-
void init_gaussian_dispatch_vector(void)
310+
void init_gaussian_dispatch_table(void)
209311
{
210-
dpctl_td_ns::DispatchVectorBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory,
211-
dpctl_td_ns::num_types>
212-
contig;
213-
contig.populate_dispatch_vector(gaussian_dispatch_vector);
312+
DispatchTableBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory, dpctl_td_ns::num_types, num_methods> contig;
313+
contig.populate(gaussian_dispatch_table);
214314
}
215315
} // namespace device
216316
} // namespace rng

dpnp/backend/extensions/rng/device/gaussian.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,15 @@ namespace rng
4242
namespace device
4343
{
4444
extern std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
45+
const std::uint8_t method_id,
4546
const std::uint32_t seed,
4647
const double mean,
4748
const double stddev,
4849
const std::uint64_t n,
4950
dpctl::tensor::usm_ndarray res,
5051
const std::vector<sycl::event> &depends = {});
5152

52-
extern void init_gaussian_dispatch_vector(void);
53+
extern void init_gaussian_dispatch_table(void);
5354
} // namespace device
5455
} // namespace rng
5556
} // namespace ext

dpnp/backend/extensions/rng/device/rng_py.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ namespace py = pybind11;
4343
// populate dispatch vectors
4444
void init_dispatch_vectors(void)
4545
{
46-
rng_dev_ext::init_gaussian_dispatch_vector();
46+
// rng_dev_ext::init_gaussian_dispatch_vector();
4747
}
4848

4949
// populate dispatch tables
5050
void init_dispatch_tables(void)
5151
{
52-
// lapack_ext::init_heevd_dispatch_table();
52+
rng_dev_ext::init_gaussian_dispatch_table();
5353
}
5454

5555

@@ -81,7 +81,7 @@ PYBIND11_MODULE(_rng_dev_impl, m)
8181

8282
m.def("_gaussian", &rng_dev_ext::gaussian,
8383
"",
84-
py::arg("sycl_queue"), py::arg("seed"), py::arg("mean"), py::arg("stddev"),
84+
py::arg("sycl_queue"), py::arg("method"), py::arg("seed"), py::arg("mean"), py::arg("stddev"),
8585
py::arg("n"), py::arg("res"),
8686
py::arg("depends") = py::list());
8787
}

0 commit comments

Comments
 (0)