Skip to content

Commit e90d30b

Browse files
Introducing dispatch_table
1 parent eb549ce commit e90d30b

File tree

9 files changed

+459
-184
lines changed

9 files changed

+459
-184
lines changed

dpnp/backend/extensions/statistics/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

dpnp/backend/extensions/statistics/common.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26+
#include <pybind11/pybind11.h>
2627
#include "common.hpp"
28+
#include "utils/type_dispatch.hpp"
29+
#include <iostream>
30+
31+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
2732

2833
namespace statistics
2934
{
@@ -78,5 +83,42 @@ size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve)
7883
return local_mem_size - reserve;
7984
}
8085

86+
pybind11::dtype dtype_from_typenum(int dst_typenum)
87+
{
88+
dpctl_td_ns::typenum_t dst_typenum_t = static_cast<dpctl_td_ns::typenum_t>(dst_typenum);
89+
switch (dst_typenum_t) {
90+
case dpctl_td_ns::typenum_t::BOOL:
91+
return py::dtype("?");
92+
case dpctl_td_ns::typenum_t::INT8:
93+
return py::dtype("i1");
94+
case dpctl_td_ns::typenum_t::UINT8:
95+
return py::dtype("u1");
96+
case dpctl_td_ns::typenum_t::INT16:
97+
return py::dtype("i2");
98+
case dpctl_td_ns::typenum_t::UINT16:
99+
return py::dtype("u2");
100+
case dpctl_td_ns::typenum_t::INT32:
101+
return py::dtype("i4");
102+
case dpctl_td_ns::typenum_t::UINT32:
103+
return py::dtype("u4");
104+
case dpctl_td_ns::typenum_t::INT64:
105+
return py::dtype("i8");
106+
case dpctl_td_ns::typenum_t::UINT64:
107+
return py::dtype("u8");
108+
case dpctl_td_ns::typenum_t::HALF:
109+
return py::dtype("f2");
110+
case dpctl_td_ns::typenum_t::FLOAT:
111+
return py::dtype("f4");
112+
case dpctl_td_ns::typenum_t::DOUBLE:
113+
return py::dtype("f8");
114+
case dpctl_td_ns::typenum_t::CFLOAT:
115+
return py::dtype("c8");
116+
case dpctl_td_ns::typenum_t::CDOUBLE:
117+
return py::dtype("c16");
118+
default:
119+
throw py::value_error("Unrecognized dst_typeid");
120+
}
121+
}
122+
81123
} // namespace common
82124
} // namespace statistics

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@
2929
#include <functional>
3030
#include <tuple>
3131
#include <type_traits>
32+
#include <pybind11/pybind11.h>
33+
#include <pybind11/numpy.h>
3234

3335
#include <sycl/sycl.hpp>
34-
3536
#include "utils/math_utils.hpp"
3637

38+
3739
namespace statistics
3840
{
3941
namespace common
@@ -180,5 +182,7 @@ sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
180182
sycl::nd_range<1>
181183
make_ndrange(size_t global_size, size_t local_range, size_t work_per_item);
182184

185+
pybind11::dtype dtype_from_typenum(int dst_typenum);
186+
183187
} // namespace common
184188
} // namespace statistics
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <vector>
29+
#include <unordered_set>
30+
31+
#include "utils/type_dispatch.hpp"
32+
#include <pybind11/pybind11.h>
33+
#include <pybind11/numpy.h>
34+
#include <pybind11/stl.h>
35+
#include <sycl/sycl.hpp>
36+
37+
#include "common.hpp"
38+
39+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
40+
namespace py = pybind11;
41+
42+
namespace statistics
43+
{
44+
namespace common
45+
{
46+
47+
template<typename T, typename Rest>
48+
struct one_of
49+
{
50+
static_assert(std::is_same_v<Rest, std::tuple<>>, "one_of: second parameter cannot be empty std::tuple");
51+
static_assert(false, "one_of: second parameter must be std::tuple");
52+
};
53+
54+
template<typename T, typename Top, typename ... Rest>
55+
struct one_of<T, std::tuple<Top, Rest...>>
56+
{
57+
static constexpr bool value = std::is_same_v<T, Top> || one_of<T, std::tuple<Rest...>>::value;
58+
};
59+
60+
template<typename T, typename Top>
61+
struct one_of<T, std::tuple<Top>>
62+
{
63+
static constexpr bool value = std::is_same_v<T, Top>;
64+
};
65+
66+
template<typename T, typename Rest>
67+
constexpr bool one_of_v = one_of<T, Rest>::value;
68+
69+
template <typename FnT>
70+
using Table = FnT[dpctl_td_ns::num_types];
71+
template <typename FnT>
72+
using Table2 = Table<FnT>[dpctl_td_ns::num_types];
73+
74+
using TypeId = int32_t;
75+
using TypesPair = std::pair<TypeId, TypeId>;
76+
77+
struct int_pair_hash
78+
{
79+
inline size_t operator()(const TypesPair &p) const
80+
{
81+
std::hash<size_t> hasher;
82+
return hasher(size_t(p.first) << (8*sizeof(TypeId)) | size_t(p.second));
83+
}
84+
};
85+
86+
using SupportedTypesList = std::vector<TypeId>;
87+
using SupportedTypesList2 = std::vector<TypesPair>;
88+
using SupportedTypesSet = std::unordered_set<TypeId>;
89+
using SupportedTypesSet2 = std::unordered_set<TypesPair, int_pair_hash>;
90+
91+
using DType = py::dtype;
92+
using DTypePair = std::pair<DType, DType>;
93+
94+
using SupportedDTypeList = std::vector<DType>;
95+
using SupportedDTypeList2 = std::vector<DTypePair>;
96+
97+
template <typename FnT,
98+
typename SupportedTypes,
99+
template <typename, typename>
100+
typename Func>
101+
struct TableBuilder2
102+
{
103+
template<typename _FnT, typename T1, typename T2>
104+
struct impl {
105+
static constexpr bool is_defined = one_of_v<std::tuple<T1, T2>, SupportedTypes>;
106+
107+
_FnT get()
108+
{
109+
if constexpr (is_defined) {
110+
return Func<T1, T2>::impl;
111+
}
112+
else {
113+
return nullptr;
114+
}
115+
}
116+
};
117+
118+
using type = dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
119+
};
120+
121+
template<typename FnT>
122+
class DispatchTable2
123+
{
124+
public:
125+
DispatchTable2(std::string first_name, std::string second_name)
126+
: first_name(first_name), second_name(second_name)
127+
{
128+
}
129+
130+
template<typename SupportedTypes,
131+
template <typename, typename> typename Func>
132+
void populate_dispatch_table()
133+
{
134+
using TBulder = typename TableBuilder2<FnT, SupportedTypes, Func>::type;
135+
TBulder builder;
136+
137+
builder.populate_dispatch_table(table);
138+
populate_supported_types();
139+
}
140+
141+
FnT get_unsafe(int first_typenum, int second_typenum) const
142+
{
143+
auto array_types = dpctl_td_ns::usm_ndarray_types();
144+
const int first_type_id = array_types.typenum_to_lookup_id(first_typenum);
145+
const int second_type_id = array_types.typenum_to_lookup_id(second_typenum);
146+
147+
return table[first_type_id][second_type_id];
148+
}
149+
150+
FnT get(int first_typenum, int second_typenum) const
151+
{
152+
auto fn = get_unsafe(first_typenum, second_typenum);
153+
154+
if (fn == nullptr)
155+
{
156+
auto array_types = dpctl_td_ns::usm_ndarray_types();
157+
const int first_type_id = array_types.typenum_to_lookup_id(first_typenum);
158+
const int second_type_id = array_types.typenum_to_lookup_id(second_typenum);
159+
160+
py::dtype first_dtype = dtype_from_typenum(first_type_id);
161+
auto first_type_pos = std::find(supported_first_type.begin(), supported_first_type.end(), first_dtype);
162+
if (first_type_pos == supported_first_type.end())
163+
{
164+
py::str types = py::str(py::cast(supported_first_type));
165+
py::str dtype = py::str(first_dtype);
166+
167+
py::str err_msg = py::str("'" + first_name + "' has unsuported type '") + dtype + py::str("'."
168+
" Supported types are: ") + types;
169+
170+
171+
throw py::value_error(static_cast<std::string>(err_msg));
172+
}
173+
174+
py::dtype second_dtype = dtype_from_typenum(second_type_id);
175+
auto second_type_pos = std::find(supported_second_type.begin(), supported_second_type.end(), second_dtype);
176+
if (second_type_pos == supported_second_type.end())
177+
{
178+
py::str types = py::str(py::cast(supported_second_type));
179+
py::str dtype = py::str(second_dtype);
180+
181+
py::str err_msg = py::str("'" + second_name + "' has unsuported type '") + dtype + py::str("'."
182+
" Supported types are: ") + types;
183+
184+
throw py::value_error(static_cast<std::string>(err_msg));
185+
}
186+
187+
py::str first_dtype_str = py::str(first_dtype);
188+
py::str second_dtype_str = py::str(second_dtype);
189+
py::str types = py::str(py::cast(all_supported_types));
190+
191+
py::str err_msg = py::str("'" + first_name + "' and '" + second_name + "' has unsupported types combination: ('") +
192+
first_dtype_str + py::str("', '") + second_dtype_str + py::str("')."
193+
" Supported types combinations are: ") + types;
194+
195+
throw py::value_error(static_cast<std::string>(err_msg));
196+
}
197+
198+
return fn;
199+
}
200+
201+
const SupportedDTypeList &get_supported_first_type() const
202+
{
203+
return supported_first_type;
204+
}
205+
206+
const SupportedDTypeList &get_supported_second_type() const
207+
{
208+
return supported_second_type;
209+
}
210+
211+
const SupportedDTypeList2 &get_all_supported_types() const
212+
{
213+
return all_supported_types;
214+
}
215+
216+
private:
217+
void populate_supported_types()
218+
{
219+
SupportedTypesSet first_supported_types_set;
220+
SupportedTypesSet second_supported_types_set;
221+
SupportedTypesSet2 all_supported_types_set;
222+
223+
for (int i = 0; i < dpctl_td_ns::num_types; ++i)
224+
{
225+
for (int j = 0; j < dpctl_td_ns::num_types; ++j)
226+
{
227+
if (table[i][j] != nullptr)
228+
{
229+
all_supported_types_set.emplace(i, j);
230+
first_supported_types_set.emplace(i);
231+
second_supported_types_set.emplace(j);
232+
}
233+
}
234+
}
235+
236+
auto to_supported_dtype_list = [](const auto& supported_set, auto& supported_list) {
237+
SupportedTypesList lst(supported_set.begin(), supported_set.end());
238+
std::sort(lst.begin(), lst.end());
239+
supported_list.resize(supported_set.size());
240+
std::transform(lst.begin(), lst.end(), supported_list.begin(),
241+
[](TypeId i) { return dtype_from_typenum(i); });
242+
};
243+
244+
to_supported_dtype_list(first_supported_types_set, supported_first_type);
245+
to_supported_dtype_list(second_supported_types_set, supported_second_type);
246+
247+
SupportedTypesList2 lst(all_supported_types_set.begin(), all_supported_types_set.end());
248+
std::sort(lst.begin(), lst.end());
249+
all_supported_types.resize(all_supported_types_set.size());
250+
std::transform(lst.begin(), lst.end(), all_supported_types.begin(),
251+
[](TypesPair p) { return DTypePair(dtype_from_typenum(p.first), dtype_from_typenum(p.second)); });
252+
}
253+
254+
std::string first_name;
255+
std::string second_name;
256+
257+
SupportedDTypeList supported_first_type;
258+
SupportedDTypeList supported_second_type;
259+
SupportedDTypeList2 all_supported_types;
260+
261+
Table2<FnT> table;
262+
};
263+
264+
} // namespace common
265+
} // namespace statistics

0 commit comments

Comments
 (0)