Skip to content

Commit 8810e37

Browse files
committed
Decoupling dispatching functionality
1 parent ad4738f commit 8810e37

File tree

5 files changed

+180
-108
lines changed

5 files changed

+180
-108
lines changed

dpnp/backend/extensions/rng/device/common_impl.hpp

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -49,24 +49,22 @@ namespace py = pybind11;
4949

5050
namespace mkl_rng_dev = oneapi::mkl::rng::device;
5151

52-
/*! @brief Functor for unary function evaluation on contiguous array */
5352
template <typename EngineBuilderT,
54-
typename DataT,
55-
typename GaussianDistrT,
53+
typename DistributorBuilderT,
5654
unsigned int items_per_wi = 4,
5755
bool enable_sg_load = true>
5856
struct RngContigFunctor
5957
{
6058
private:
61-
// const std::uint32_t seed_;
59+
using DataT = typename DistributorBuilderT::result_type;
60+
6261
EngineBuilderT engine_;
63-
GaussianDistrT distr_;
62+
DistributorBuilderT distr_;
6463
DataT * const res_ = nullptr;
6564
const size_t nelems_;
6665

6766
public:
68-
69-
RngContigFunctor(EngineBuilderT& engine, GaussianDistrT& distr, DataT *res, const size_t n_elems)
67+
RngContigFunctor(EngineBuilderT& engine, DistributorBuilderT& distr, DataT *res, const size_t n_elems)
7068
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
7169
{
7270
}
@@ -82,7 +80,7 @@ struct RngContigFunctor
8280
using EngineT = typename EngineBuilderT::EngineType;
8381
EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
8482

85-
using DistrT = typename GaussianDistrT::distr_type;
83+
using DistrT = typename DistributorBuilderT::distr_type;
8684
DistrT distr = distr_();
8785

8886
constexpr std::size_t vec_sz = EngineT::vec_size;

dpnp/backend/extensions/rng/device/dispatch/matrix.hpp

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,56 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl/rng/device.hpp>
29+
30+
#include "utils/type_dispatch.hpp"
31+
32+
33+
namespace dpnp::backend::ext::rng::device::dispatch
34+
{
35+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
36+
namespace mkl_rng_dev = oneapi::mkl::rng::device;
37+
38+
/**
 * @brief Entry of a (data type, method) support list.
 *
 * The inherited `value` is true only when `Ty` is the same type as `ArgTy`
 * and `Method` is the same type as `argMethod`. `is_defined` is always true,
 * regardless of whether the entry matched.
 *
 * @tparam Ty        Data type being queried.
 * @tparam ArgTy     Data type this entry stands for.
 * @tparam Method    Method type being queried.
 * @tparam argMethod Method type this entry stands for.
 */
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
struct TypePairDefinedEntry
    : std::bool_constant<std::conjunction_v<std::is_same<Ty, ArgTy>,
                                            std::is_same<Method, argMethod>>>
{
    static constexpr bool is_defined = true;
};
44+
45+
template <typename T, typename M>
46+
struct GaussianTypePairSupportFactory
47+
{
48+
static constexpr bool is_defined = std::disjunction<
49+
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::by_default>,
50+
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::box_muller2>,
51+
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::by_default>,
52+
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::box_muller2>,
53+
// fall-through
54+
dpctl_td_ns::NotDefinedEntry>::is_defined;
55+
};
56+
} // dpnp::backend::ext::rng::device::dispatch

dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp

Lines changed: 100 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,100 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl/rng/device.hpp>
29+
30+
31+
namespace dpnp::backend::ext::rng::device::dispatch
32+
{
33+
namespace mkl_rng_dev = oneapi::mkl::rng::device;
34+
35+
template <typename funcPtrT,
36+
template <typename fnT, typename E, typename T, typename M> typename factory,
37+
int _no_of_engines,
38+
int _no_of_types,
39+
int _no_of_methods>
40+
class Dispatch3DTableBuilder
41+
{
42+
private:
43+
template <typename E, typename T>
44+
const std::vector<funcPtrT> row_per_method() const
45+
{
46+
std::vector<funcPtrT> per_method = {
47+
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}.get(),
48+
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}.get(),
49+
};
50+
assert(per_method.size() == _no_of_methods);
51+
return per_method;
52+
}
53+
54+
template <typename E>
55+
auto table_per_type_and_method() const
56+
{
57+
std::vector<std::vector<funcPtrT>>
58+
table_by_type = {row_per_method<E, bool>(),
59+
row_per_method<E, int8_t>(),
60+
row_per_method<E, uint8_t>(),
61+
row_per_method<E, int16_t>(),
62+
row_per_method<E, uint16_t>(),
63+
row_per_method<E, int32_t>(),
64+
row_per_method<E, uint32_t>(),
65+
row_per_method<E, int64_t>(),
66+
row_per_method<E, uint64_t>(),
67+
row_per_method<E, sycl::half>(),
68+
row_per_method<E, float>(),
69+
row_per_method<E, double>(),
70+
row_per_method<E, std::complex<float>>(),
71+
row_per_method<E, std::complex<double>>()};
72+
assert(table_by_type.size() == _no_of_types);
73+
return table_by_type;
74+
}
75+
76+
public:
77+
Dispatch3DTableBuilder() = default;
78+
~Dispatch3DTableBuilder() = default;
79+
80+
void populate(funcPtrT table[][_no_of_types][_no_of_methods]) const
81+
{
82+
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<8>>()};
83+
assert(map_by_engine.size() == _no_of_engines);
84+
85+
std::uint16_t engine_id = 0;
86+
for (auto &table_by_type : map_by_engine) {
87+
std::uint16_t type_id = 0;
88+
for (auto &row_by_method : table_by_type) {
89+
std::uint16_t method_id = 0;
90+
for (auto &fn_ptr : row_by_method) {
91+
table[engine_id][type_id][method_id] = fn_ptr;
92+
++method_id;
93+
}
94+
++type_id;
95+
}
96+
++engine_id;
97+
}
98+
}
99+
};
100+
} // dpnp::backend::ext::rng::device::dispatch

dpnp/backend/extensions/rng/device/engine/base_builder.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ class BaseBuilder {
6767
}
6868
}
6969

70-
inline auto operator()() const
70+
inline auto operator()(void) const
7171
{
7272
switch (no_of_seeds) {
7373
case 1: {

dpnp/backend/extensions/rng/device/gaussian.cpp

Lines changed: 17 additions & 99 deletions
Original file line number | Diff line number | Diff line change
@@ -33,17 +33,15 @@
3333
// dpctl tensor headers
3434
#include "kernels/alignment.hpp"
3535

36-
using dpctl::tensor::kernels::alignment_utils::disabled_sg_loadstore_wrapper_krn;
37-
using dpctl::tensor::kernels::alignment_utils::is_aligned;
38-
using dpctl::tensor::kernels::alignment_utils::required_alignment;
39-
4036
#include "common_impl.hpp"
4137
#include "gaussian.hpp"
4238

4339
#include "engine/engine_base.hpp"
4440
#include "engine/engine_builder.hpp"
4541

46-
// #include "dpnp_utils.hpp"
42+
#include "dispatch/matrix.hpp"
43+
#include "dispatch/table_builder.hpp"
44+
4745

4846
namespace dpnp
4947
{
@@ -55,26 +53,31 @@ namespace rng
5553
{
5654
namespace device
5755
{
56+
namespace dpctl_krn_ns = dpctl::tensor::kernels::alignment_utils;
5857
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
5958
namespace mkl_rng_dev = oneapi::mkl::rng::device;
6059
namespace py = pybind11;
6160
namespace type_utils = dpctl::tensor::type_utils;
6261

62+
using dpctl_krn_ns::disabled_sg_loadstore_wrapper_krn;
63+
using dpctl_krn_ns::is_aligned;
64+
using dpctl_krn_ns::required_alignment;
65+
6366
constexpr int no_of_methods = 2; // number of methods of gaussian distribution
6467

6568
template <typename DataT, typename Method>
66-
struct GaussianDistr
69+
struct DistributorBuilder
6770
{
6871
private:
6972
const DataT mean_;
7073
const DataT stddev_;
7174

7275
public:
73-
using method_type = Method;
7476
using result_type = DataT;
77+
using method_type = Method;
7578
using distr_type = typename mkl_rng_dev::gaussian<DataT, Method>;
7679

77-
GaussianDistr(const DataT mean, const DataT stddev)
80+
DistributorBuilder(const DataT mean, const DataT stddev)
7881
: mean_(mean), stddev_(stddev)
7982
{
8083
}
@@ -128,23 +131,23 @@ static sycl::event gaussian_impl(engine::EngineBase *engine,
128131
EngineBuilderT eng_builder(engine);
129132
eng_builder.print(); // TODO: remove
130133

131-
using GaussianDistrT = GaussianDistr<DataT, Method>;
132-
GaussianDistrT distr(mean, stddev);
134+
using DistributorBuilderT = DistributorBuilder<DataT, Method>;
135+
DistributorBuilderT dist_builder(mean, stddev);
133136

134137
if (is_aligned<required_alignment>(out_ptr)) {
135138
constexpr bool enable_sg_load = true;
136139
using KernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
137140

138141
cgh.parallel_for<KernelName>(sycl::nd_range<1>({global_size}, {local_size}),
139-
details::RngContigFunctor<EngineBuilderT, DataT, GaussianDistrT, items_per_wi, enable_sg_load>(eng_builder, distr, out, n));
142+
details::RngContigFunctor<EngineBuilderT, DistributorBuilderT, items_per_wi, enable_sg_load>(eng_builder, dist_builder, out, n));
140143
}
141144
else {
142145
constexpr bool disable_sg_load = false;
143146
using InnerKernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
144147
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
145148

146149
cgh.parallel_for<KernelName>(sycl::nd_range<1>({global_size}, {local_size}),
147-
details::RngContigFunctor<EngineBuilderT, DataT, GaussianDistrT, items_per_wi, disable_sg_load>(eng_builder, distr, out, n));
150+
details::RngContigFunctor<EngineBuilderT, DistributorBuilderT, items_per_wi, disable_sg_load>(eng_builder, dist_builder, out, n));
148151
}
149152
});
150153
} catch (oneapi::mkl::exception const &e) {
@@ -225,97 +228,12 @@ std::pair<sycl::event, sycl::event> gaussian(engine::EngineBase *engine,
225228
return std::make_pair(ht_ev, gaussian_ev);
226229
}
227230

228-
template <typename funcPtrT,
229-
template <typename fnT, typename E, typename T, typename M> typename factory,
230-
int _no_of_engines,
231-
int _no_of_types,
232-
int _no_of_methods>
233-
class Dispatch3DTableBuilder
234-
{
235-
private:
236-
template <typename E, typename T>
237-
const std::vector<funcPtrT> row_per_method() const
238-
{
239-
std::vector<funcPtrT> per_method = {
240-
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::by_default>{}.get(),
241-
factory<funcPtrT, E, T, mkl_rng_dev::gaussian_method::box_muller2>{}.get(),
242-
};
243-
assert(per_method.size() == _no_of_methods);
244-
return per_method;
245-
}
246-
247-
template <typename E>
248-
auto table_per_type_and_method() const
249-
{
250-
std::vector<std::vector<funcPtrT>>
251-
table_by_type = {row_per_method<E, bool>(),
252-
row_per_method<E, int8_t>(),
253-
row_per_method<E, uint8_t>(),
254-
row_per_method<E, int16_t>(),
255-
row_per_method<E, uint16_t>(),
256-
row_per_method<E, int32_t>(),
257-
row_per_method<E, uint32_t>(),
258-
row_per_method<E, int64_t>(),
259-
row_per_method<E, uint64_t>(),
260-
row_per_method<E, sycl::half>(),
261-
row_per_method<E, float>(),
262-
row_per_method<E, double>(),
263-
row_per_method<E, std::complex<float>>(),
264-
row_per_method<E, std::complex<double>>()};
265-
assert(table_by_type.size() == _no_of_types);
266-
return table_by_type;
267-
}
268-
269-
public:
270-
Dispatch3DTableBuilder() = default;
271-
~Dispatch3DTableBuilder() = default;
272-
273-
void populate(funcPtrT table[][_no_of_types][_no_of_methods]) const
274-
{
275-
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<8>>()};
276-
assert(map_by_engine.size() == _no_of_engines);
277-
278-
std::uint16_t engine_id = 0;
279-
for (auto &table_by_type : map_by_engine) {
280-
std::uint16_t type_id = 0;
281-
for (auto &row_by_method : table_by_type) {
282-
std::uint16_t method_id = 0;
283-
for (auto &fn_ptr : row_by_method) {
284-
table[engine_id][type_id][method_id] = fn_ptr;
285-
++method_id;
286-
}
287-
++type_id;
288-
}
289-
++engine_id;
290-
}
291-
}
292-
};
293-
294-
template <typename Ty, typename ArgTy, typename Method, typename argMethod>
295-
struct TypePairDefinedEntry : std::bool_constant<std::is_same_v<Ty, ArgTy> &&
296-
std::is_same_v<Method, argMethod>>
297-
{
298-
static constexpr bool is_defined = true;
299-
};
300-
301-
template <typename T, typename M>
302-
struct GaussianTypePairSupportFactory
303-
{
304-
static constexpr bool is_defined = std::disjunction<
305-
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::by_default>,
306-
TypePairDefinedEntry<T, double, M, mkl_rng_dev::gaussian_method::box_muller2>,
307-
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::by_default>,
308-
TypePairDefinedEntry<T, float, M, mkl_rng_dev::gaussian_method::box_muller2>,
309-
// fall-through
310-
dpctl_td_ns::NotDefinedEntry>::is_defined;
311-
};
312-
313231
template <typename fnT, typename E, typename T, typename M>
314232
struct GaussianContigFactory
315233
{
316234
fnT get()
317235
{
318-
if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
236+
if constexpr (dispatch::GaussianTypePairSupportFactory<T, M>::is_defined) {
319237
return gaussian_impl<E, T, M>;
320238
}
321239
else {
@@ -326,7 +244,7 @@ struct GaussianContigFactory
326244

327245
void init_gaussian_dispatch_table(void)
328246
{
329-
Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
247+
dispatch::Dispatch3DTableBuilder<gaussian_impl_fn_ptr_t, GaussianContigFactory, engine::no_of_engines, dpctl_td_ns::num_types, no_of_methods> contig;
330248
contig.populate(gaussian_dispatch_table);
331249
}
332250
} // namespace device

0 commit comments

Comments
 (0)