1
+ // *****************************************************************************
2
+ // Copyright (c) 2024, Intel Corporation
3
+ // All rights reserved.
4
+ //
5
+ // Redistribution and use in source and binary forms, with or without
6
+ // modification, are permitted provided that the following conditions are met:
7
+ // - Redistributions of source code must retain the above copyright notice,
8
+ // this list of conditions and the following disclaimer.
9
+ // - Redistributions in binary form must reproduce the above copyright notice,
10
+ // this list of conditions and the following disclaimer in the documentation
11
+ // and/or other materials provided with the distribution.
12
+ //
13
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
+ // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
+ // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
+ // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
+ // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
+ // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
+ // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
+ // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
+ // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
+ // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
+ // THE POSSIBILITY OF SUCH DAMAGE.
24
+ // *****************************************************************************
25
+
26
+ #pragma once
27
+
28
+ #include < pybind11/pybind11.h>
29
+
30
+ #include < sycl/sycl.hpp>
31
+ #include < oneapi/mkl/rng/device.hpp>
32
+
33
+ // dpctl tensor headers
34
+ #include " kernels/alignment.hpp"
35
+ #include " utils/offset_utils.hpp"
36
+
37
+ namespace dpnp
38
+ {
39
+ namespace backend
40
+ {
41
+ namespace ext
42
+ {
43
+ namespace rng
44
+ {
45
+ namespace device
46
+ {
47
+ namespace details
48
+ {
49
+ namespace py = pybind11;
50
+
51
+ using dpctl::tensor::kernels::alignment_utils::is_aligned;
52
+ using dpctl::tensor::kernels::alignment_utils::required_alignment;
53
+
54
+ namespace mkl_rng_dev = oneapi::mkl::rng::device;
55
+
56
+ /* ! @brief Functor for unary function evaluation on contiguous array */
57
+ template <typename DataT,
58
+ typename ResT = DataT,
59
+ typename Method = mkl_rng_dev::gaussian_method::by_default,
60
+ typename UnaryOperatorT = ResT,
61
+ unsigned int vec_sz = 8 ,
62
+ unsigned int items_per_wi = 4 ,
63
+ bool enable_sg_load = true >
64
+ struct RngContigFunctor
65
+ {
66
+ private:
67
+ const std::uint32_t seed_;
68
+ const DataT mean_;
69
+ const DataT stddev_;
70
+ ResT *res_ = nullptr ;
71
+ const size_t nelems_;
72
+
73
+ public:
74
+ RngContigFunctor (const std::uint32_t seed, const DataT mean, const DataT stddev, ResT *res, const size_t n_elems)
75
+ : seed_(seed), mean_(mean), stddev_(stddev), res_(res), nelems_(n_elems)
76
+ {
77
+ }
78
+
79
+ void operator ()(sycl::nd_item<1 > nd_it) const
80
+ {
81
+ auto global_id = nd_it.get_global_id ();
82
+
83
+ auto sg = nd_it.get_sub_group ();
84
+ const std::uint8_t sg_size = sg.get_local_range ()[0 ];
85
+ const std::uint8_t max_sg_size = sg.get_max_local_range ()[0 ];
86
+
87
+ auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id);
88
+ mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
89
+
90
+ if (enable_sg_load) {
91
+ const size_t base = items_per_wi * vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
92
+
93
+ if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
94
+ #pragma unroll
95
+ for (std::uint16_t it = 0 ; it < items_per_wi * vec_sz; it += vec_sz) {
96
+ size_t offset = base + static_cast <size_t >(it) * static_cast <size_t >(sg_size);
97
+ auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
98
+
99
+ sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate (distr, engine);
100
+ sg.store <vec_sz>(out_multi_ptr, rng_val_vec);
101
+ }
102
+ }
103
+ else {
104
+ for (size_t offset = base + sg.get_local_id ()[0 ]; offset < nelems_; offset += sg_size) {
105
+ res_[offset] = mkl_rng_dev::generate_single (distr, engine);
106
+ }
107
+ }
108
+ }
109
+ else {
110
+ size_t base = nd_it.get_global_linear_id ();
111
+
112
+ base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
113
+ for (size_t offset = base; offset < std::min (nelems_, base + sg_size * (items_per_wi * vec_sz)); offset += sg_size)
114
+ {
115
+ res_[offset] = mkl_rng_dev::generate_single (distr, engine);
116
+ }
117
+ }
118
+ }
119
+ };
120
+
121
+ template <typename DataT,
122
+ typename ResT = DataT,
123
+ typename Method = mkl_rng_dev::gaussian_method::by_default,
124
+ typename IndexerT = ResT,
125
+ typename UnaryOpT = ResT>
126
+ struct RngStridedFunctor
127
+ {
128
+ private:
129
+ const std::uint32_t seed_;
130
+ const double mean_;
131
+ const double stddev_;
132
+ ResT *res_ = nullptr ;
133
+ IndexerT out_indexer_;
134
+
135
+ public:
136
+ RngStridedFunctor (const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
137
+ : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
138
+ {
139
+ }
140
+
141
+ void operator ()(sycl::id<1 > wid) const
142
+ {
143
+ const auto res_offset = out_indexer_ (wid.get (0 ));
144
+
145
+ // UnaryOpT op{};
146
+
147
+ auto engine = mkl_rng_dev::mrg32k3a (seed_);
148
+ mkl_rng_dev::gaussian<DataT, Method> distr (mean_, stddev_);
149
+
150
+ res_[res_offset] = mkl_rng_dev::generate (distr, engine);
151
+ }
152
+ };
153
+ } // namespace details
154
+ } // namespace device
155
+ } // namespace rng
156
+ } // namespace ext
157
+ } // namespace backend
158
+ } // namespace dpnp
0 commit comments