31
31
#include < oneapi/mkl/rng/device.hpp>
32
32
33
33
// dpctl tensor headers
34
- #include " kernels/alignment.hpp"
35
- #include " utils/offset_utils.hpp"
34
+ // #include "utils/offset_utils.hpp"
36
35
37
36
namespace dpnp
38
37
{
@@ -48,9 +47,6 @@ namespace details
48
47
{
49
48
namespace py = pybind11;
50
49
51
- using dpctl::tensor::kernels::alignment_utils::is_aligned;
52
- using dpctl::tensor::kernels::alignment_utils::required_alignment;
53
-
54
50
namespace mkl_rng_dev = oneapi::mkl::rng::device;
55
51
56
52
/* ! @brief Functor for unary function evaluation on contiguous array */
@@ -67,7 +63,7 @@ struct RngContigFunctor
67
63
const std::uint32_t seed_;
68
64
const DataT mean_;
69
65
const DataT stddev_;
70
- ResT *res_ = nullptr ;
66
+ ResT * const res_ = nullptr ;
71
67
const size_t nelems_;
72
68
73
69
public:
@@ -84,10 +80,10 @@ struct RngContigFunctor
84
80
const std::uint8_t sg_size = sg.get_local_range ()[0 ];
85
81
const std::uint8_t max_sg_size = sg.get_max_local_range ()[0 ];
86
82
87
- auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id);
83
+ auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id); // offset is questionable...
88
84
mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
89
85
90
- if (enable_sg_load) {
86
+ if constexpr (enable_sg_load) {
91
87
const size_t base = items_per_wi * vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
92
88
93
89
if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
@@ -118,38 +114,38 @@ struct RngContigFunctor
118
114
}
119
115
};
120
116
121
- template <typename DataT,
122
- typename ResT = DataT,
123
- typename Method = mkl_rng_dev::gaussian_method::by_default,
124
- typename IndexerT = ResT,
125
- typename UnaryOpT = ResT>
126
- struct RngStridedFunctor
127
- {
128
- private:
129
- const std::uint32_t seed_;
130
- const double mean_;
131
- const double stddev_;
132
- ResT *res_ = nullptr ;
133
- IndexerT out_indexer_;
134
-
135
- public:
136
- RngStridedFunctor (const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
137
- : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
138
- {
139
- }
140
-
141
- void operator ()(sycl::id<1 > wid) const
142
- {
143
- const auto res_offset = out_indexer_ (wid.get (0 ));
144
-
145
- // UnaryOpT op{};
146
-
147
- auto engine = mkl_rng_dev::mrg32k3a (seed_);
148
- mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
149
-
150
- res_[res_offset] = mkl_rng_dev::generate (distr, engine);
151
- }
152
- };
117
+ // template <typename DataT,
118
+ // typename ResT = DataT,
119
+ // typename Method = mkl_rng_dev::gaussian_method::by_default,
120
+ // typename IndexerT = ResT,
121
+ // typename UnaryOpT = ResT>
122
+ // struct RngStridedFunctor
123
+ // {
124
+ // private:
125
+ // const std::uint32_t seed_;
126
+ // const double mean_;
127
+ // const double stddev_;
128
+ // ResT *res_ = nullptr;
129
+ // IndexerT out_indexer_;
130
+
131
+ // public:
132
+ // RngStridedFunctor(const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
133
+ // : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
134
+ // {
135
+ // }
136
+
137
+ // void operator()(sycl::id<1> wid) const
138
+ // {
139
+ // const auto res_offset = out_indexer_(wid.get(0));
140
+
141
+ // // UnaryOpT op{};
142
+
143
+ // auto engine = mkl_rng_dev::mrg32k3a(seed_);
144
+ // mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
145
+
146
+ // res_[res_offset] = mkl_rng_dev::generate(distr, engine);
147
+ // }
148
+ // };
153
149
} // namespace details
154
150
} // namespace device
155
151
} // namespace rng
0 commit comments