Skip to content

Commit ad4738f

Browse files
committed
Dispatching by an engine
1 parent 082835f commit ad4738f

13 files changed

+561
-212
lines changed

dpnp/backend/extensions/rng/device/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ endif()
4040

4141
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
4242

43+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
4344
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
4445
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
4546

dpnp/backend/extensions/rng/device/common_impl.hpp

Lines changed: 6 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ namespace py = pybind11;
5050
namespace mkl_rng_dev = oneapi::mkl::rng::device;
5151

5252
/*! @brief Functor for unary function evaluation on contiguous array */
53-
template <typename EngineDistrT,
53+
template <typename EngineBuilderT,
5454
typename DataT,
5555
typename GaussianDistrT,
5656
unsigned int items_per_wi = 4,
@@ -59,32 +59,28 @@ struct RngContigFunctor
5959
{
6060
private:
6161
// const std::uint32_t seed_;
62-
EngineDistrT engine_;
62+
EngineBuilderT engine_;
6363
GaussianDistrT distr_;
6464
DataT * const res_ = nullptr;
6565
const size_t nelems_;
6666

6767
public:
6868

69-
RngContigFunctor(EngineDistrT& engine, GaussianDistrT& distr, DataT *res, const size_t n_elems)
69+
RngContigFunctor(EngineBuilderT& engine, GaussianDistrT& distr, DataT *res, const size_t n_elems)
7070
: engine_(engine), distr_(distr), res_(res), nelems_(n_elems)
7171
{
7272
}
7373

7474
void operator()(sycl::nd_item<1> nd_it) const
7575
{
76-
// auto global_id = nd_it.get_global_id();
77-
78-
// constexpr std::size_t vec_sz = EngineT::vec_size;
76+
auto global_id = nd_it.get_global_id();
7977

8078
auto sg = nd_it.get_sub_group();
8179
const std::uint8_t sg_size = sg.get_local_range()[0];
8280
const std::uint8_t max_sg_size = sg.get_max_local_range()[0];
8381

84-
// auto engine = EngineT(seed_, nelems_ * global_id); // offset is questionable...
85-
86-
using EngineT = typename EngineDistrT::engine_type;
87-
EngineT engine = engine_();
82+
using EngineT = typename EngineBuilderT::EngineType;
83+
EngineT engine = engine_(nelems_ * global_id); // offset is questionable...
8884

8985
using DistrT = typename GaussianDistrT::distr_type;
9086
DistrT distr = distr_();
@@ -121,39 +117,6 @@ struct RngContigFunctor
121117
}
122118
}
123119
};
124-
125-
// template <typename DataT,
126-
// typename ResT = DataT,
127-
// typename Method = mkl_rng_dev::gaussian_method::by_default,
128-
// typename IndexerT = ResT,
129-
// typename UnaryOpT = ResT>
130-
// struct RngStridedFunctor
131-
// {
132-
// private:
133-
// const std::uint32_t seed_;
134-
// const double mean_;
135-
// const double stddev_;
136-
// ResT *res_ = nullptr;
137-
// IndexerT out_indexer_;
138-
139-
// public:
140-
// RngStridedFunctor(const std::uint32_t seed, const double mean, const double stddev, ResT *res_p, IndexerT out_indexer)
141-
// : seed_(seed), mean_(mean), stddev_(stddev), res_(res_p), out_indexer_(out_indexer)
142-
// {
143-
// }
144-
145-
// void operator()(sycl::id<1> wid) const
146-
// {
147-
// const auto res_offset = out_indexer_(wid.get(0));
148-
149-
// // UnaryOpT op{};
150-
151-
// auto engine = mkl_rng_dev::mrg32k3a(seed_);
152-
// mkl_rng_dev::gaussian<DataT, Method> distr(mean_, stddev_);
153-
154-
// res_[res_offset] = mkl_rng_dev::generate(distr, engine);
155-
// }
156-
// };
157120
} // namespace details
158121
} // namespace device
159122
} // namespace rng
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include "engine_base.hpp"
29+
30+
31+
namespace dpnp::backend::ext::rng::device::engine
32+
{
33+
template <typename EngineT, typename SeedT, typename OffsetT>
34+
class BaseBuilder {
35+
private:
36+
static constexpr std::uint8_t max_n = 10;
37+
38+
std::uint8_t no_of_seeds;
39+
std::uint8_t no_of_offsets;
40+
41+
std::array<SeedT, max_n> seeds{};
42+
std::array<OffsetT, max_n> offsets{};
43+
44+
public:
45+
BaseBuilder(EngineBase *engine)
46+
{
47+
auto seed_values = engine->get_seeds();
48+
no_of_seeds = seed_values.size();
49+
if (no_of_seeds > max_n) {
50+
throw std::runtime_error("");
51+
}
52+
53+
// TODO: implement a caster
54+
for (std::uint16_t i = 0; i < no_of_seeds; i++) {
55+
seeds[i] = static_cast<SeedT>(seed_values[i]);
56+
}
57+
58+
auto offset_values = engine->get_offsets();
59+
no_of_offsets = offset_values.size();
60+
if (no_of_offsets > max_n) {
61+
throw std::runtime_error("");
62+
}
63+
64+
// TODO: implement a caster
65+
for (std::uint16_t i = 0; i < no_of_seeds; i++) {
66+
offsets[i] = static_cast<OffsetT>(offset_values[i]);
67+
}
68+
}
69+
70+
inline auto operator()() const
71+
{
72+
switch (no_of_seeds) {
73+
case 1: {
74+
return EngineT({seeds[0]}, {offsets[0]});
75+
}
76+
// TODO: implement full switch
77+
default:
78+
break;
79+
}
80+
return EngineT();
81+
}
82+
83+
inline auto operator()(OffsetT offset) const
84+
{
85+
switch (no_of_seeds) {
86+
case 1: {
87+
return EngineT({seeds[0]}, offset);
88+
}
89+
// TODO: implement full switch
90+
default:
91+
break;
92+
}
93+
return EngineT();
94+
}
95+
96+
// TODO: remove
97+
void print() {
98+
std::cout << "list_of_seeds: ";
99+
for (auto &val: seeds) {
100+
std::cout << std::to_string(val) << ", ";
101+
}
102+
std::cout << std::endl;
103+
}
104+
};
105+
} // dpnp::backend::ext::rng::device::engine
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <sycl/sycl.hpp>
29+
30+
31+
namespace dpnp::backend::ext::rng::device::engine
32+
{
33+
class EngineType {
34+
public:
35+
enum Type : std::uint8_t {
36+
MRG32k3a = 0,
37+
Base, // must be the last always
38+
};
39+
40+
EngineType() = default;
41+
constexpr EngineType(Type type) : type_(type) {}
42+
43+
constexpr std::uint8_t id() const {
44+
return static_cast<std::uint8_t>(type_);
45+
}
46+
47+
static constexpr std::uint8_t base_id() {
48+
return EngineType(Base).id();
49+
}
50+
51+
private:
52+
Type type_;
53+
};
54+
55+
// A total number of supported engines == EngineType::Base
56+
constexpr int no_of_engines = EngineType::base_id();
57+
58+
class EngineBase {
59+
public:
60+
virtual ~EngineBase() {}
61+
virtual sycl::queue &get_queue() = 0;
62+
63+
virtual EngineType get_type() const noexcept {
64+
return EngineType::Base;
65+
}
66+
67+
virtual std::vector<std::uint64_t> get_seeds() const noexcept {
68+
return std::vector<std::uint64_t>();
69+
}
70+
71+
virtual std::vector<std::uint64_t> get_offsets() const noexcept {
72+
return std::vector<std::uint64_t>();
73+
}
74+
};
75+
} // dpnp::backend::ext::rng::device::engine
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
29+
namespace dpnp::backend::ext::rng::device::engine
30+
{
31+
template <typename Type>
32+
class Builder {};
33+
} // dpnp::backend::ext::rng::device::engine
34+
35+
#include "mrg32k3a_builder.hpp"
36+
#include "philox4x32x10_builder.hpp"
37+
#include "mcg31m1_builder.hpp"
38+
#include "mcg59_builder.hpp"
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl/rng/device.hpp>
29+
30+
#include "engine_base.hpp"
31+
#include "base_builder.hpp"
32+
33+
namespace dpnp::backend::ext::rng::device::engine
34+
{
35+
namespace mkl_rng_dev = oneapi::mkl::rng::device;
36+
37+
template <std::int32_t VecSize>
38+
class Builder<mkl_rng_dev::mcg31m1<VecSize>> : public BaseBuilder<mkl_rng_dev::mcg31m1<VecSize>, std::uint32_t, std::uint64_t> {
39+
public:
40+
using EngineType = mkl_rng_dev::mcg31m1<VecSize>;
41+
42+
Builder(EngineBase *engine) : BaseBuilder<EngineType, std::uint32_t, std::uint64_t>(engine) {}
43+
};
44+
} // dpnp::backend::ext::rng::device::engine

0 commit comments

Comments
 (0)