Skip to content

Commit 1754216

Browse files
committed
Added strided kernel
1 parent a722089 commit 1754216

File tree

2 files changed

+366
-75
lines changed

2 files changed

+366
-75
lines changed
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
#include <sycl/sycl.hpp>
31+
#include <oneapi/mkl/rng/device.hpp>
32+
33+
// dpctl tensor headers
34+
#include "kernels/alignment.hpp"
35+
#include "utils/offset_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace rng
44+
{
45+
namespace device
46+
{
47+
namespace details
48+
{
49+
namespace py = pybind11;
50+
51+
using dpctl::tensor::kernels::alignment_utils::is_aligned;
52+
using dpctl::tensor::kernels::alignment_utils::required_alignment;
53+
54+
namespace mkl_rng_dev = oneapi::mkl::rng::device;
55+
56+
/*! @brief Functor for unary function evaluation on contiguous array */
57+
template <typename DataT,
58+
typename ResT = DataT,
59+
typename Method = mkl_rng_dev::gaussian_method::by_default,
60+
typename UnaryOperatorT = ResT,
61+
unsigned int vec_sz = 8,
62+
unsigned int items_per_wi = 4,
63+
bool enable_sg_load = true>
64+
struct RngContigFunctor
65+
{
66+
private:
67+
const std::uint32_t seed_;
68+
const DataT mean_;
69+
const DataT stddev_;
70+
ResT *res_ = nullptr;
71+
const size_t nelems_;
72+
73+
public:
74+
RngContigFunctor(const std::uint32_t seed, const DataT mean, const DataT stddev, ResT *res, const size_t n_elems)
75+
: seed_(seed), mean_(mean), stddev_(stddev), res_(res), nelems_(n_elems)
76+
{
77+
}
78+
79+
void operator()(sycl::nd_item<1> nd_it) const
80+
{
81+
auto global_id = nd_it.get_global_id();
82+
83+
auto sg = nd_it.get_sub_group();
84+
const std::uint8_t sg_size = sg.get_local_range()[0];
85+
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
86+
87+
auto engine = mkl_rng_dev::mrg32k3a<vec_sz>(seed_, nelems_ * global_id);
88+
mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
89+
90+
if (enable_sg_load) {
91+
const size_t base = items_per_wi * vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
92+
93+
if ((sg_size == max_sg_size) && (base + items_per_wi * vec_sz * sg_size < nelems_)) {
94+
#pragma unroll
95+
for (std::uint16_t it = 0; it < items_per_wi * vec_sz; it += vec_sz) {
96+
size_t offset = base + static_cast<size_t>(it) * static_cast<size_t>(sg_size);
97+
auto out_multi_ptr = sycl::address_space_cast<sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res_[offset]);
98+
99+
sycl::vec<DataT, vec_sz> rng_val_vec = mkl_rng_dev::generate(distr, engine);
100+
sg.store<vec_sz>(out_multi_ptr, rng_val_vec);
101+
}
102+
}
103+
else {
104+
for (size_t offset = base + sg.get_local_id()[0]; offset < nelems_; offset += sg_size) {
105+
res_[offset] = mkl_rng_dev::generate_single(distr, engine);
106+
}
107+
}
108+
}
109+
else {
110+
size_t base = nd_it.get_global_linear_id();
111+
112+
base = (base / sg_size) * sg_size * items_per_wi * vec_sz + (base % sg_size);
113+
for (size_t offset = base; offset < std::min(nelems_, base + sg_size * (items_per_wi * vec_sz)); offset += sg_size)
114+
{
115+
res_[offset] = mkl_rng_dev::generate_single(distr, engine);
116+
}
117+
}
118+
}
119+
};
120+
121+
template <typename DataT,
122+
typename ResT = DataT,
123+
typename Method = mkl_rng_dev::gaussian_method::by_default,
124+
typename IndexerT = ResT,
125+
typename UnaryOpT = ResT>
126+
struct RngStridedFunctor
127+
{
128+
private:
129+
const std::uint32_t seed_;
130+
const double mean_;
131+
const double stddev_;
132+
ResT *res_ = nullptr;
133+
IndexerT out_indexer_;
134+
135+
public:
136+
RngStridedFunctor(const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
137+
: seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
138+
{
139+
}
140+
141+
void operator()(sycl::id<1> wid) const
142+
{
143+
const auto res_offset = out_indexer_(wid.get(0));
144+
145+
// UnaryOpT op{};
146+
147+
auto engine = mkl_rng_dev::mrg32k3a(seed_);
148+
mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
149+
150+
res_[res_offset] = mkl_rng_dev::generate(distr, engine);
151+
}
152+
};
153+
} // namespace details
154+
} // namespace device
155+
} // namespace rng
156+
} // namespace ext
157+
} // namespace backend
158+
} // namespace dpnp

0 commit comments

Comments
 (0)