Skip to content

Commit 114ab1d

Browse files
committed
Disabled strided implementation
1 parent 1754216 commit 114ab1d

File tree

2 files changed

+59
-235
lines changed

2 files changed

+59
-235
lines changed

dpnp/backend/extensions/rng/device/common_impl.hpp

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
3131
#include <oneapi/mkl/rng/device.hpp>
3232

3333
// dpctl tensor headers
34-
#include "kernels/alignment.hpp"
35-
#include "utils/offset_utils.hpp"
34+
// #include "utils/offset_utils.hpp"
3635

3736
namespace dpnp
3837
{
@@ -48,9 +47,6 @@ namespace details
4847
{
4948
namespace py = pybind11;
5049

51-
using dpctl::tensor::kernels::alignment_utils::is_aligned;
52-
using dpctl::tensor::kernels::alignment_utils::required_alignment;
53-
5450
namespace mkl_rng_dev = oneapi::mkl::rng::device;
5551

5652
/*! @brief Functor for unary function evaluation on contiguous array */
@@ -67,7 +63,7 @@ struct RngContigFunctor
6763
const std::uint32_t seed_;
6864
const DataT mean_;
6965
const DataT stddev_;
70-
ResT *res_ = nullptr;
66+
ResT * const res_ = nullptr;
7167
const size_t nelems_;
7268

7369
public:
@@ -84,10 +80,10 @@ struct RngContigFunctor
8480
const std::uint8_t sg_size = sg.get_local_range()[0];
8581
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
8682

87-
auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id);
83+
auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id); // offset is questionable...
8884
mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
8985

90-
if (enable_sg_load) {
86+
if constexpr (enable_sg_load) {
9187
const size_t base = items_per_wi * vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
9288

9389
if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
@@ -118,38 +114,38 @@ struct RngContigFunctor
118114
}
119115
};
120116

121-
template <typename DataT,
122-
typename ResT = DataT,
123-
typename Method = mkl_rng_dev::gaussian_method::by_default,
124-
typename IndexerT = ResT,
125-
typename UnaryOpT = ResT>
126-
struct RngStridedFunctor
127-
{
128-
private:
129-
const std::uint32_t seed_;
130-
const double mean_;
131-
const double stddev_;
132-
ResT *res_ = nullptr;
133-
IndexerT out_indexer_;
134-
135-
public:
136-
RngStridedFunctor(const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
137-
: seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
138-
{
139-
}
140-
141-
void operator()(sycl::id<1> wid) const
142-
{
143-
const auto res_offset = out_indexer_(wid.get(0));
144-
145-
// UnaryOpT op{};
146-
147-
auto engine = mkl_rng_dev::mrg32k3a(seed_);
148-
mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
149-
150-
res_[res_offset] = mkl_rng_dev::generate(distr, engine);
151-
}
152-
};
117+
// template <typename DataT,
118+
// typename ResT = DataT,
119+
// typename Method = mkl_rng_dev::gaussian_method::by_default,
120+
// typename IndexerT = ResT,
121+
// typename UnaryOpT = ResT>
122+
// struct RngStridedFunctor
123+
// {
124+
// private:
125+
// const std::uint32_t seed_;
126+
// const double mean_;
127+
// const double stddev_;
128+
// ResT *res_ = nullptr;
129+
// IndexerT out_indexer_;
130+
131+
// public:
132+
// RngStridedFunctor(const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
133+
// : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
134+
// {
135+
// }
136+
137+
// void operator()(sycl::id<1> wid) const
138+
// {
139+
// const auto res_offset = out_indexer_(wid.get(0));
140+
141+
// // UnaryOpT op{};
142+
143+
// auto engine = mkl_rng_dev::mrg32k3a(seed_);
144+
// mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
145+
146+
// res_[res_offset] = mkl_rng_dev::generate(distr, engine);
147+
// }
148+
// };
153149
} // namespace details
154150
} // namespace device
155151
} // namespace rng

0 commit comments

Comments
 (0)