Skip to content

Commit 7fc1df1

Browse files
Introducing dispatch_table
1 parent 5bbc680 commit 7fc1df1

File tree

9 files changed

+496
-187
lines changed

9 files changed

+496
-187
lines changed

dpnp/backend/extensions/statistics/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

dpnp/backend/extensions/statistics/common.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
//*****************************************************************************
2525

2626
#include "common.hpp"
27+
#include "utils/type_dispatch.hpp"
28+
#include <pybind11/pybind11.h>
29+
30+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
2731

2832
namespace statistics
2933
{
@@ -78,5 +82,43 @@ size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve)
7882
return local_mem_size - reserve;
7983
}
8084

85+
pybind11::dtype dtype_from_typenum(int dst_typenum)
86+
{
87+
dpctl_td_ns::typenum_t dst_typenum_t =
88+
static_cast<dpctl_td_ns::typenum_t>(dst_typenum);
89+
switch (dst_typenum_t) {
90+
case dpctl_td_ns::typenum_t::BOOL:
91+
return py::dtype("?");
92+
case dpctl_td_ns::typenum_t::INT8:
93+
return py::dtype("i1");
94+
case dpctl_td_ns::typenum_t::UINT8:
95+
return py::dtype("u1");
96+
case dpctl_td_ns::typenum_t::INT16:
97+
return py::dtype("i2");
98+
case dpctl_td_ns::typenum_t::UINT16:
99+
return py::dtype("u2");
100+
case dpctl_td_ns::typenum_t::INT32:
101+
return py::dtype("i4");
102+
case dpctl_td_ns::typenum_t::UINT32:
103+
return py::dtype("u4");
104+
case dpctl_td_ns::typenum_t::INT64:
105+
return py::dtype("i8");
106+
case dpctl_td_ns::typenum_t::UINT64:
107+
return py::dtype("u8");
108+
case dpctl_td_ns::typenum_t::HALF:
109+
return py::dtype("f2");
110+
case dpctl_td_ns::typenum_t::FLOAT:
111+
return py::dtype("f4");
112+
case dpctl_td_ns::typenum_t::DOUBLE:
113+
return py::dtype("f8");
114+
case dpctl_td_ns::typenum_t::CFLOAT:
115+
return py::dtype("c8");
116+
case dpctl_td_ns::typenum_t::CDOUBLE:
117+
return py::dtype("c16");
118+
default:
119+
throw py::value_error("Unrecognized dst_typeid");
120+
}
121+
}
122+
81123
} // namespace common
82124
} // namespace statistics

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@
2626
#pragma once
2727

2828
#include <complex>
29-
#include <functional>
30-
#include <tuple>
31-
#include <type_traits>
29+
#include <pybind11/numpy.h>
30+
#include <pybind11/pybind11.h>
3231

32+
// clang-format off
33+
// math_utils.hpp doesn't include sycl header but uses sycl types
34+
// so sycl.hpp must be included before math_utils.hpp
3335
#include <sycl/sycl.hpp>
34-
3536
#include "utils/math_utils.hpp"
37+
// clang-format on
3638

3739
namespace statistics
3840
{
@@ -180,5 +182,9 @@ sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
180182
sycl::nd_range<1>
181183
make_ndrange(size_t global_size, size_t local_range, size_t work_per_item);
182184

185+
// This function is a copy from dpctl because it is not available in the public
186+
// headers of dpctl.
187+
pybind11::dtype dtype_from_typenum(int dst_typenum);
188+
183189
} // namespace common
184190
} // namespace statistics
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <unordered_set>
29+
#include <vector>
30+
31+
#include "utils/type_dispatch.hpp"
32+
#include <pybind11/numpy.h>
33+
#include <pybind11/pybind11.h>
34+
#include <pybind11/stl.h>
35+
#include <sycl/sycl.hpp>
36+
37+
#include "common.hpp"
38+
39+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
40+
namespace py = pybind11;
41+
42+
namespace statistics
43+
{
44+
namespace common
45+
{
46+
47+
template <typename T, typename Rest>
48+
struct one_of
49+
{
50+
static_assert(std::is_same_v<Rest, std::tuple<>>,
51+
"one_of: second parameter cannot be empty std::tuple");
52+
static_assert(false, "one_of: second parameter must be std::tuple");
53+
};
54+
55+
template <typename T, typename Top, typename... Rest>
56+
struct one_of<T, std::tuple<Top, Rest...>>
57+
{
58+
static constexpr bool value =
59+
std::is_same_v<T, Top> || one_of<T, std::tuple<Rest...>>::value;
60+
};
61+
62+
template <typename T, typename Top>
63+
struct one_of<T, std::tuple<Top>>
64+
{
65+
static constexpr bool value = std::is_same_v<T, Top>;
66+
};
67+
68+
template <typename T, typename Rest>
69+
constexpr bool one_of_v = one_of<T, Rest>::value;
70+
71+
template <typename FnT>
72+
using Table = FnT[dpctl_td_ns::num_types];
73+
template <typename FnT>
74+
using Table2 = Table<FnT>[dpctl_td_ns::num_types];
75+
76+
using TypeId = int32_t;
77+
using TypesPair = std::pair<TypeId, TypeId>;
78+
79+
struct int_pair_hash
80+
{
81+
inline size_t operator()(const TypesPair &p) const
82+
{
83+
std::hash<size_t> hasher;
84+
return hasher(size_t(p.first) << (8 * sizeof(TypeId)) |
85+
size_t(p.second));
86+
}
87+
};
88+
89+
using SupportedTypesList = std::vector<TypeId>;
90+
using SupportedTypesList2 = std::vector<TypesPair>;
91+
using SupportedTypesSet = std::unordered_set<TypeId>;
92+
using SupportedTypesSet2 = std::unordered_set<TypesPair, int_pair_hash>;
93+
94+
using DType = py::dtype;
95+
using DTypePair = std::pair<DType, DType>;
96+
97+
using SupportedDTypeList = std::vector<DType>;
98+
using SupportedDTypeList2 = std::vector<DTypePair>;
99+
100+
template <typename FnT,
101+
typename SupportedTypes,
102+
template <typename, typename>
103+
typename Func>
104+
struct TableBuilder2
105+
{
106+
template <typename _FnT, typename T1, typename T2>
107+
struct impl
108+
{
109+
static constexpr bool is_defined =
110+
one_of_v<std::tuple<T1, T2>, SupportedTypes>;
111+
112+
_FnT get()
113+
{
114+
if constexpr (is_defined) {
115+
return Func<T1, T2>::impl;
116+
}
117+
else {
118+
return nullptr;
119+
}
120+
}
121+
};
122+
123+
using type =
124+
dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
125+
};
126+
127+
template <typename FnT>
128+
class DispatchTable2
129+
{
130+
public:
131+
DispatchTable2(std::string first_name, std::string second_name)
132+
: first_name(first_name), second_name(second_name)
133+
{
134+
}
135+
136+
template <typename SupportedTypes,
137+
template <typename, typename>
138+
typename Func>
139+
void populate_dispatch_table()
140+
{
141+
using TBulder = typename TableBuilder2<FnT, SupportedTypes, Func>::type;
142+
TBulder builder;
143+
144+
builder.populate_dispatch_table(table);
145+
populate_supported_types();
146+
}
147+
148+
FnT get_unsafe(int first_typenum, int second_typenum) const
149+
{
150+
auto array_types = dpctl_td_ns::usm_ndarray_types();
151+
const int first_type_id =
152+
array_types.typenum_to_lookup_id(first_typenum);
153+
const int second_type_id =
154+
array_types.typenum_to_lookup_id(second_typenum);
155+
156+
return table[first_type_id][second_type_id];
157+
}
158+
159+
FnT get(int first_typenum, int second_typenum) const
160+
{
161+
auto fn = get_unsafe(first_typenum, second_typenum);
162+
163+
if (fn == nullptr) {
164+
auto array_types = dpctl_td_ns::usm_ndarray_types();
165+
const int first_type_id =
166+
array_types.typenum_to_lookup_id(first_typenum);
167+
const int second_type_id =
168+
array_types.typenum_to_lookup_id(second_typenum);
169+
170+
py::dtype first_dtype = dtype_from_typenum(first_type_id);
171+
auto first_type_pos =
172+
std::find(supported_first_type.begin(),
173+
supported_first_type.end(), first_dtype);
174+
if (first_type_pos == supported_first_type.end()) {
175+
py::str types = py::str(py::cast(supported_first_type));
176+
py::str dtype = py::str(first_dtype);
177+
178+
py::str err_msg =
179+
py::str("'" + first_name + "' has unsupported type '") +
180+
dtype +
181+
py::str("'."
182+
" Supported types are: ") +
183+
types;
184+
185+
throw py::value_error(static_cast<std::string>(err_msg));
186+
}
187+
188+
py::dtype second_dtype = dtype_from_typenum(second_type_id);
189+
auto second_type_pos =
190+
std::find(supported_second_type.begin(),
191+
supported_second_type.end(), second_dtype);
192+
if (second_type_pos == supported_second_type.end()) {
193+
py::str types = py::str(py::cast(supported_second_type));
194+
py::str dtype = py::str(second_dtype);
195+
196+
py::str err_msg =
197+
py::str("'" + second_name + "' has unsupported type '") +
198+
dtype +
199+
py::str("'."
200+
" Supported types are: ") +
201+
types;
202+
203+
throw py::value_error(static_cast<std::string>(err_msg));
204+
}
205+
206+
py::str first_dtype_str = py::str(first_dtype);
207+
py::str second_dtype_str = py::str(second_dtype);
208+
py::str types = py::str(py::cast(all_supported_types));
209+
210+
py::str err_msg =
211+
py::str("'" + first_name + "' and '" + second_name +
212+
"' has unsupported types combination: ('") +
213+
first_dtype_str + py::str("', '") + second_dtype_str +
214+
py::str("')."
215+
" Supported types combinations are: ") +
216+
types;
217+
218+
throw py::value_error(static_cast<std::string>(err_msg));
219+
}
220+
221+
return fn;
222+
}
223+
224+
const SupportedDTypeList &get_supported_first_type() const
225+
{
226+
return supported_first_type;
227+
}
228+
229+
const SupportedDTypeList &get_supported_second_type() const
230+
{
231+
return supported_second_type;
232+
}
233+
234+
const SupportedDTypeList2 &get_all_supported_types() const
235+
{
236+
return all_supported_types;
237+
}
238+
239+
private:
240+
void populate_supported_types()
241+
{
242+
SupportedTypesSet first_supported_types_set;
243+
SupportedTypesSet second_supported_types_set;
244+
SupportedTypesSet2 all_supported_types_set;
245+
246+
for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
247+
for (int j = 0; j < dpctl_td_ns::num_types; ++j) {
248+
if (table[i][j] != nullptr) {
249+
all_supported_types_set.emplace(i, j);
250+
first_supported_types_set.emplace(i);
251+
second_supported_types_set.emplace(j);
252+
}
253+
}
254+
}
255+
256+
auto to_supported_dtype_list = [](const auto &supported_set,
257+
auto &supported_list) {
258+
SupportedTypesList lst(supported_set.begin(), supported_set.end());
259+
std::sort(lst.begin(), lst.end());
260+
supported_list.resize(supported_set.size());
261+
std::transform(lst.begin(), lst.end(), supported_list.begin(),
262+
[](TypeId i) { return dtype_from_typenum(i); });
263+
};
264+
265+
to_supported_dtype_list(first_supported_types_set,
266+
supported_first_type);
267+
to_supported_dtype_list(second_supported_types_set,
268+
supported_second_type);
269+
270+
SupportedTypesList2 lst(all_supported_types_set.begin(),
271+
all_supported_types_set.end());
272+
std::sort(lst.begin(), lst.end());
273+
all_supported_types.resize(all_supported_types_set.size());
274+
std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
275+
[](TypesPair p) {
276+
return DTypePair(dtype_from_typenum(p.first),
277+
dtype_from_typenum(p.second));
278+
});
279+
}
280+
281+
std::string first_name;
282+
std::string second_name;
283+
284+
SupportedDTypeList supported_first_type;
285+
SupportedDTypeList supported_second_type;
286+
SupportedDTypeList2 all_supported_types;
287+
288+
Table2<FnT> table;
289+
};
290+
291+
} // namespace common
292+
} // namespace statistics

0 commit comments

Comments
 (0)