Skip to content

Commit 9cff286

Browse files
committed
Add pybind class for engines
1 parent 84074fd commit 9cff286

File tree

4 files changed

+126
-39
lines changed

4 files changed

+126
-39
lines changed

dpnp/backend/extensions/rng/device/common_impl.hpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,42 +50,47 @@ namespace py = pybind11;
5050
namespace mkl_rng_dev = oneapi::mkl::rng::device;
5151

5252
/*! @brief Functor for unary function evaluation on contiguous array */
53-
template <typename DataT,
53+
template <typename EngineDistrT,
54+
typename DataT,
5455
typename GaussianDistrT,
55-
typename ResT = DataT,
56-
typename UnaryOperatorT = ResT,
57-
unsigned int vec_sz = 8,
5856
unsigned int items_per_wi = 4,
5957
bool enable_sg_load = true>
6058
struct RngContigFunctor
6159
{
6260
private:
63-
const std::uint32_t seed_;
61+
// const std::uint32_t seed_;
62+
EngineDistrT engine_;
6463
GaussianDistrT distr_;
65-
ResT * const res_ = nullptr;
64+
DataT * const res_ = nullptr;
6665
const size_t nelems_;
6766

6867
public:
6968

70-
RngContigFunctor(const std::uint32_t seed, GaussianDistrT& distr, ResT *res, const size_t n_elems)
71-
: seed_(seed), distr_(distr), res_(res), nelems_(n_elems)
69+
RngContigFunctor(EngineDistrT& engine, GaussianDistrT& distr, DataT *res, const size_t n_elems)
70+
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
7271
{
7372
}
7473

7574
void operator()(sycl::nd_item<1> nd_it) const
7675
{
77-
auto global_id = nd_it.get_global_id();
76+
// auto global_id = nd_it.get_global_id();
77+
78+
// constexpr std::size_t vec_sz = EngineT::vec_size;
7879

7980
auto sg = nd_it.get_sub_group();
8081
const std::uint8_t sg_size = sg.get_local_range()[0];
8182
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
8283

83-
using EngineT = typename mkl_rng_dev::mrg32k3a<vec_sz>;
84-
auto engine = EngineT(seed_, nelems_ * global_id); // offset is questionable...
84+
// auto engine = EngineT(seed_, nelems_ * global_id); // offset is questionable...
85+
86+
using EngineT = typename EngineDistrT::engine_type;
87+
EngineT engine = engine_();
8588

8689
using DistrT = typename GaussianDistrT::distr_type;
8790
DistrT distr = distr_();
8891

92+
constexpr std::size_t vec_sz = EngineT::vec_size;
93+
8994
if constexpr (enable_sg_load) {
9095
const size_t base = items_per_wi * vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
9196

dpnp/backend/extensions/rng/device/gaussian.cpp

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,26 @@ struct GaussianDistr
9191
}
9292
};
9393

94-
typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
94+
/*!
 * @brief Factory functor that builds a device RNG engine from an engine description.
 *
 * The constructor copies `seed_` and `offset_` out of the object pointed to by
 * `engine`; `operator()` then constructs `MklEngineT(seed, offset)` from those
 * copies.
 *
 * Why the values are copied instead of keeping the pointer: RngContigFunctor
 * stores this object by value and invokes it inside its
 * `operator()(sycl::nd_item<1>)`, i.e. in kernel code. A pointer to a
 * host-allocated engine object generally cannot be dereferenced there, so the
 * functor carries only the plain integers it needs.
 *
 * @tparam EngineBase  Type exposing accessible `seed_` and `offset_` data
 *                     members. NOTE: this template parameter has the same
 *                     name as the `EngineBase` class declared in gaussian.hpp
 *                     and hides it inside this struct.
 * @tparam MklEngineT  Engine type constructible from `(seed, offset)`.
 */
template <typename EngineBase, typename MklEngineT>
struct EngineDistr
{
private:
    std::uint32_t seed_;
    std::uint64_t offset_;

public:
    using engine_type = MklEngineT;

    /*!
     * @param engine  Pointer to the engine description; must not be null. It
     *                is only read during construction and is not retained.
     */
    EngineDistr(EngineBase *engine)
        : seed_(engine->seed_), offset_(engine->offset_)
    {
    }

    /*! @brief Returns a freshly constructed `MklEngineT(seed, offset)`. */
    inline auto operator()(void) const
    {
        return MklEngineT(seed_, offset_);
    }
};
112+
113+
typedef sycl::event (*gaussian_impl_fn_ptr_t)(EngineBase *engine,
95114
const std::uint32_t,
96115
const double,
97116
const double,
@@ -101,25 +120,26 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
101120

102121
static gaussian_impl_fn_ptr_t gaussian_dispatch_table[dpctl_td_ns::num_types][num_methods];
103122

104-
template <typename DataT, typename Method, unsigned int vec_sz, unsigned int items_per_wi>
123+
template <typename EngineT, typename DataT, typename Method, unsigned int items_per_wi>
105124
class gaussian_kernel;
106125

107-
template <typename DataT, typename Method>
108-
static sycl::event gaussian_impl(sycl::queue& exec_q,
126+
template <typename EngineT, typename DataT, typename Method>
127+
static sycl::event gaussian_impl(EngineBase *engine,
109128
const std::uint32_t seed,
110129
const double mean_val,
111130
const double stddev_val,
112131
const std::uint64_t n,
113132
char *out_ptr,
114133
const std::vector<sycl::event> &depends)
115134
{
135+
auto exec_q = engine->get_queue();
116136
type_utils::validate_type_for_device<DataT>(exec_q);
117137

118138
DataT *out = reinterpret_cast<DataT *>(out_ptr);
119139
DataT mean = static_cast<DataT>(mean_val);
120140
DataT stddev = static_cast<DataT>(stddev_val);
121141

122-
constexpr std::size_t vec_sz = 8;
142+
constexpr std::size_t vec_sz = EngineT::vec_size;
123143
constexpr std::size_t items_per_wi = 4;
124144
constexpr std::size_t local_size = 256;
125145
const std::size_t wg_items = local_size * vec_sz * items_per_wi;
@@ -131,23 +151,28 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
131151
distr_event = exec_q.submit([&](sycl::handler &cgh) {
132152
cgh.depends_on(depends);
133153

154+
using EngineDistrT = EngineDistr<MRG32k3a, EngineT>;
155+
EngineDistrT eng(static_cast<MRG32k3a*>(engine));
156+
157+
// EngineT engine = EngineT(seed, 0);
158+
134159
using GaussianDistrT = GaussianDistr<DataT, Method>;
135160
GaussianDistrT distr(mean, stddev);
136161

137162
if (is_aligned<required_alignment>(out_ptr)) {
138163
constexpr bool enable_sg_load = true;
139-
using KernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
164+
using KernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
140165

141166
cgh.parallel_for<KernelName>(sycl::nd_range<1>({global_size}, {local_size}),
142-
details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, enable_sg_load>(seed, distr, out, n));
167+
details::RngContigFunctor<EngineDistrT, DataT, GaussianDistrT, items_per_wi, enable_sg_load>(eng, distr, out, n));
143168
}
144169
else {
145170
constexpr bool disable_sg_load = false;
146-
using InnerKernelName = gaussian_kernel<DataT, Method, vec_sz, items_per_wi>;
171+
using InnerKernelName = gaussian_kernel<EngineT, DataT, Method, items_per_wi>;
147172
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
148173

149174
cgh.parallel_for<KernelName>(sycl::nd_range<1>({global_size}, {local_size}),
150-
details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, disable_sg_load>(seed, distr, out, n));
175+
details::RngContigFunctor<EngineDistrT, DataT, GaussianDistrT, items_per_wi, disable_sg_load>(eng, distr, out, n));
151176
}
152177
});
153178
} catch (oneapi::mkl::exception const &e) {
@@ -164,7 +189,7 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
164189
return distr_event;
165190
}
166191

167-
std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
192+
std::pair<sycl::event, sycl::event> gaussian(EngineBase *engine,
168193
const std::uint8_t method_id,
169194
const std::uint32_t seed,
170195
const double mean,
@@ -173,6 +198,9 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
173198
dpctl::tensor::usm_ndarray res,
174199
const std::vector<sycl::event> &depends)
175200
{
201+
std::cout << engine->print() << std::endl;
202+
auto exec_q = engine->get_queue();
203+
176204
const int res_nd = res.get_ndim();
177205
const py::ssize_t *res_shape = res.get_shape_raw();
178206

@@ -216,7 +244,7 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
216244
}
217245

218246
char *res_data = res.get_data();
219-
sycl::event gaussian_ev = gaussian_fn(exec_q, seed, mean, stddev, n, res_data, depends);
247+
sycl::event gaussian_ev = gaussian_fn(engine, seed, mean, stddev, n, res_data, depends);
220248

221249
sycl::event ht_ev = dpctl::utils::keep_args_alive(exec_q, {res}, {gaussian_ev});
222250
return std::make_pair(ht_ev, gaussian_ev);
@@ -299,7 +327,7 @@ struct GaussianContigFactory
299327
fnT get()
300328
{
301329
if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
302-
return gaussian_impl<T, M>;
330+
return gaussian_impl<mkl_rng_dev::mrg32k3a<8>, T, M>;
303331
}
304332
else {
305333
return nullptr;

dpnp/backend/extensions/rng/device/gaussian.hpp

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,43 @@
3131

3232
#include <dpctl4pybind11.hpp>
3333

34-
namespace dpnp
35-
{
36-
namespace backend
37-
{
38-
namespace ext
39-
{
40-
namespace rng
41-
{
42-
namespace device
34+
class EngineBase {
35+
public:
36+
virtual ~EngineBase() {}
37+
virtual sycl::queue get_queue() = 0;
38+
virtual std::string print() = 0;
39+
// auto get_engine() {
40+
// return nullptr;
41+
// }
42+
};
43+
44+
class MRG32k3a : public EngineBase {
45+
public:
46+
sycl::queue q_;
47+
const std::uint32_t seed_;
48+
const std::uint64_t offset_;
49+
50+
// public:
51+
MRG32k3a(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q), seed_(seed), offset_(offset) {}
52+
53+
sycl::queue get_queue() override {
54+
return q_;
55+
}
56+
57+
std::string print() override {
58+
return "seed = " + std::to_string(seed_) + ", offset = " + std::to_string(offset_);
59+
}
60+
61+
// auto get_engine() override {
62+
// return oneapi::mkl::rng::device::mrg32k3a<8>(seed_, offset_);
63+
// }
64+
65+
// using engine_type = oneapi::mkl::rng::device::mrg32k3a<8>;
66+
};
67+
68+
namespace dpnp::backend::ext::rng::device
4369
{
44-
extern std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
70+
extern std::pair<sycl::event, sycl::event> gaussian(EngineBase *engine,
4571
const std::uint8_t method_id,
4672
const std::uint32_t seed,
4773
const double mean,
@@ -51,8 +77,4 @@ extern std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
5177
const std::vector<sycl::event> &depends = {});
5278

5379
extern void init_gaussian_dispatch_table(void);
54-
} // namespace device
55-
} // namespace rng
56-
} // namespace ext
57-
} // namespace backend
58-
} // namespace dpnp
80+
} // namespace dpnp::backend::ext::rng::device

dpnp/backend/extensions/rng/device/rng_py.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,29 @@ void init_dispatch_tables(void)
5252
rng_dev_ext::init_gaussian_dispatch_table();
5353
}
5454

55+
class PyEngineBase : public EngineBase {
56+
public:
57+
/* Inherit the constructors */
58+
using EngineBase::EngineBase;
59+
60+
/* Trampoline (need one for each virtual function) */
61+
sycl::queue get_queue() override {
62+
PYBIND11_OVERRIDE_PURE(
63+
sycl::queue, /* Return type */
64+
EngineBase, /* Parent class */
65+
get_queue, /* Name of function in C++ (must match Python name) */
66+
);
67+
}
68+
69+
std::string print() override {
70+
PYBIND11_OVERRIDE_PURE(
71+
std::string, /* Return type */
72+
EngineBase, /* Parent class */
73+
print, /* Name of function in C++ (must match Python name) */
74+
);
75+
}
76+
};
77+
5578

5679
PYBIND11_MODULE(_rng_dev_impl, m)
5780
{
@@ -79,9 +102,18 @@ PYBIND11_MODULE(_rng_dev_impl, m)
79102
// py::arg("eig_vecs"), py::arg("eig_vals"),
80103
// py::arg("depends") = py::list());
81104

105+
py::class_<EngineBase, PyEngineBase /* <--- trampoline */>(m, "EngineBase")
106+
.def(py::init<>())
107+
.def("print", &EngineBase::print);
108+
109+
py::class_<MRG32k3a, EngineBase>(m, "MRG32k3a")
110+
.def(py::init<sycl::queue &, std::uint32_t, std::uint64_t>());
111+
112+
82113
m.def("_gaussian", &rng_dev_ext::gaussian,
83114
"",
84-
py::arg("sycl_queue"), py::arg("method"), py::arg("seed"), py::arg("mean"), py::arg("stddev"),
115+
py::arg("engine"),
116+
py::arg("method"), py::arg("seed"), py::arg("mean"), py::arg("stddev"),
85117
py::arg("n"), py::arg("res"),
86118
py::arg("depends") = py::list());
87119
}

0 commit comments

Comments
 (0)