@@ -91,7 +91,26 @@ struct GaussianDistr
91
91
}
92
92
};
93
93
94
- typedef sycl::event (*gaussian_impl_fn_ptr_t )(sycl::queue &,
94
/**
 * @brief Factory functor that produces a new MklEngineT each time it is called.
 *
 * The seed and offset are read from the host-side engine object once, in the
 * constructor, and kept by value. The functor therefore holds no pointer back
 * to the host object: it stays valid after the host engine is destroyed and
 * can be copied wherever the caller needs it.
 *
 * NOTE(review): this functor is handed to the kernel functor passed to
 * parallel_for; presumably operator() is invoked inside the kernel, where
 * dereferencing a host pointer would be invalid -- confirm against
 * RngContigFunctor, which is not visible here.
 *
 * @tparam EngineBase  Host-side engine type; must expose accessible `seed_`
 *                     and `offset_` data members.
 * @tparam MklEngineT  Engine type constructible from (seed, offset).
 */
template <typename EngineBase, typename MklEngineT>
struct EngineDistr
{
private:
    // Value snapshots of the host engine parameters; the types follow the
    // declarations in EngineBase so no conversion is introduced.
    decltype(EngineBase::seed_) seed_;
    decltype(EngineBase::offset_) offset_;

public:
    using engine_type = MklEngineT;

    /**
     * @param engine  Non-null pointer to the host engine; only read during
     *                construction, never stored.
     */
    EngineDistr(EngineBase *engine)
        : seed_(engine->seed_), offset_(engine->offset_)
    {
    }

    /// @return A newly constructed MklEngineT(seed, offset).
    auto operator()() const
    {
        return MklEngineT(seed_, offset_);
    }
};
112
+
113
+ typedef sycl::event (*gaussian_impl_fn_ptr_t )(EngineBase *engine,
95
114
const std::uint32_t ,
96
115
const double ,
97
116
const double ,
@@ -101,25 +120,26 @@ typedef sycl::event (*gaussian_impl_fn_ptr_t)(sycl::queue &,
101
120
102
121
// Two-dimensional table of gaussian implementation function pointers,
// dimensioned dpctl_td_ns::num_types x num_methods.
// NOTE(review): presumably indexed as [result type id][method id] and filled
// from GaussianContigFactory (nullptr for unsupported pairs) -- the population
// and lookup code are not visible here; confirm.
static gaussian_impl_fn_ptr_t gaussian_dispatch_table[dpctl_td_ns::num_types][num_methods];
103
122
104
- template <typename DataT, typename Method, unsigned int vec_sz , unsigned int items_per_wi>
123
+ template <typename EngineT, typename DataT, typename Method , unsigned int items_per_wi>
105
124
class gaussian_kernel ;
106
125
107
- template <typename DataT, typename Method>
108
- static sycl::event gaussian_impl (sycl::queue& exec_q ,
126
+ template <typename EngineT, typename DataT, typename Method>
127
+ static sycl::event gaussian_impl (EngineBase *engine ,
109
128
const std::uint32_t seed,
110
129
const double mean_val,
111
130
const double stddev_val,
112
131
const std::uint64_t n,
113
132
char *out_ptr,
114
133
const std::vector<sycl::event> &depends)
115
134
{
135
+ auto exec_q = engine->get_queue ();
116
136
type_utils::validate_type_for_device<DataT>(exec_q);
117
137
118
138
DataT *out = reinterpret_cast <DataT *>(out_ptr);
119
139
DataT mean = static_cast <DataT>(mean_val);
120
140
DataT stddev = static_cast <DataT>(stddev_val);
121
141
122
- constexpr std::size_t vec_sz = 8 ;
142
+ constexpr std::size_t vec_sz = EngineT::vec_size ;
123
143
constexpr std::size_t items_per_wi = 4 ;
124
144
constexpr std::size_t local_size = 256 ;
125
145
const std::size_t wg_items = local_size * vec_sz * items_per_wi;
@@ -131,23 +151,28 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
131
151
distr_event = exec_q.submit ([&](sycl::handler &cgh) {
132
152
cgh.depends_on (depends);
133
153
154
+ using EngineDistrT = EngineDistr<MRG32k3a, EngineT>;
155
+ EngineDistrT eng (static_cast <MRG32k3a*>(engine));
156
+
157
+ // EngineT engine = EngineT(seed, 0);
158
+
134
159
using GaussianDistrT = GaussianDistr<DataT, Method>;
135
160
GaussianDistrT distr (mean, stddev);
136
161
137
162
if (is_aligned<required_alignment>(out_ptr)) {
138
163
constexpr bool enable_sg_load = true ;
139
- using KernelName = gaussian_kernel<DataT, Method, vec_sz , items_per_wi>;
164
+ using KernelName = gaussian_kernel<EngineT, DataT, Method , items_per_wi>;
140
165
141
166
cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
142
- details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, enable_sg_load>(seed , distr, out, n));
167
+ details::RngContigFunctor<EngineDistrT, DataT, GaussianDistrT, items_per_wi, enable_sg_load>(eng , distr, out, n));
143
168
}
144
169
else {
145
170
constexpr bool disable_sg_load = false ;
146
- using InnerKernelName = gaussian_kernel<DataT, Method, vec_sz , items_per_wi>;
171
+ using InnerKernelName = gaussian_kernel<EngineT, DataT, Method , items_per_wi>;
147
172
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
148
173
149
174
cgh.parallel_for <KernelName>(sycl::nd_range<1 >({global_size}, {local_size}),
150
- details::RngContigFunctor<DataT, GaussianDistrT, DataT, DataT, vec_sz, items_per_wi, disable_sg_load>(seed , distr, out, n));
175
+ details::RngContigFunctor<EngineDistrT, DataT, GaussianDistrT, items_per_wi, disable_sg_load>(eng , distr, out, n));
151
176
}
152
177
});
153
178
} catch (oneapi::mkl::exception const &e) {
@@ -164,7 +189,7 @@ static sycl::event gaussian_impl(sycl::queue& exec_q,
164
189
return distr_event;
165
190
}
166
191
167
- std::pair<sycl::event, sycl::event> gaussian (sycl::queue exec_q ,
192
+ std::pair<sycl::event, sycl::event> gaussian (EngineBase *engine ,
168
193
const std::uint8_t method_id,
169
194
const std::uint32_t seed,
170
195
const double mean,
@@ -173,6 +198,9 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
173
198
dpctl::tensor::usm_ndarray res,
174
199
const std::vector<sycl::event> &depends)
175
200
{
201
+ std::cout << engine->print () << std::endl;
202
+ auto exec_q = engine->get_queue ();
203
+
176
204
const int res_nd = res.get_ndim ();
177
205
const py::ssize_t *res_shape = res.get_shape_raw ();
178
206
@@ -216,7 +244,7 @@ std::pair<sycl::event, sycl::event> gaussian(sycl::queue exec_q,
216
244
}
217
245
218
246
char *res_data = res.get_data ();
219
- sycl::event gaussian_ev = gaussian_fn (exec_q , seed, mean, stddev, n, res_data, depends);
247
+ sycl::event gaussian_ev = gaussian_fn (engine , seed, mean, stddev, n, res_data, depends);
220
248
221
249
sycl::event ht_ev = dpctl::utils::keep_args_alive (exec_q, {res}, {gaussian_ev});
222
250
return std::make_pair (ht_ev, gaussian_ev);
@@ -299,7 +327,7 @@ struct GaussianContigFactory
299
327
fnT get ()
300
328
{
301
329
if constexpr (GaussianTypePairSupportFactory<T, M>::is_defined) {
302
- return gaussian_impl<T, M>;
330
+ return gaussian_impl<mkl_rng_dev::mrg32k3a< 8 >, T, M>;
303
331
}
304
332
else {
305
333
return nullptr ;
0 commit comments