|
39 | 39 | namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
|
40 | 40 | namespace py = pybind11;
|
41 | 41 |
|
42 |
| -namespace statistics |
| 42 | +namespace statistics::common |
43 | 43 | {
|
44 |
| -namespace common |
45 |
| -{ |
46 |
| - |
47 | 44 | template <typename T, typename Rest>
|
48 | 45 | struct one_of
|
49 | 46 | {
|
@@ -97,6 +94,32 @@ using DTypePair = std::pair<DType, DType>;
|
97 | 94 | using SupportedDTypeList = std::vector<DType>;
|
98 | 95 | using SupportedDTypeList2 = std::vector<DTypePair>;
|
99 | 96 |
|
| 97 | +template <typename FnT, |
| 98 | + typename SupportedTypes, |
| 99 | + template <typename> |
| 100 | + typename Func> |
| 101 | +struct TableBuilder |
| 102 | +{ |
| 103 | + template <typename _FnT, typename T> |
| 104 | + struct impl |
| 105 | + { |
| 106 | + static constexpr bool is_defined = one_of_v<T, SupportedTypes>; |
| 107 | + |
| 108 | + _FnT get() |
| 109 | + { |
| 110 | + if constexpr (is_defined) { |
| 111 | + return Func<T>::impl; |
| 112 | + } |
| 113 | + else { |
| 114 | + return nullptr; |
| 115 | + } |
| 116 | + } |
| 117 | + }; |
| 118 | + |
| 119 | + using type = |
| 120 | + dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>; |
| 121 | +}; |
| 122 | + |
100 | 123 | template <typename FnT,
|
101 | 124 | typename SupportedTypes,
|
102 | 125 | template <typename, typename>
|
@@ -124,6 +147,78 @@ struct TableBuilder2
|
124 | 147 | dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
|
125 | 148 | };
|
126 | 149 |
|
| 150 | +template <typename FnT> |
| 151 | +class DispatchTable |
| 152 | +{ |
| 153 | +public: |
| 154 | + DispatchTable(std::string name) : name(name) {} |
| 155 | + |
| 156 | + template <typename SupportedTypes, template <typename> typename Func> |
| 157 | + void populate_dispatch_table() |
| 158 | + { |
| 159 | + using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type; |
| 160 | + TBulder builder; |
| 161 | + |
| 162 | + builder.populate_dispatch_vector(table); |
| 163 | + populate_supported_types(); |
| 164 | + } |
| 165 | + |
| 166 | + FnT get_unsafe(int _typenum) const |
| 167 | + { |
| 168 | + auto array_types = dpctl_td_ns::usm_ndarray_types(); |
| 169 | + const int type_id = array_types.typenum_to_lookup_id(_typenum); |
| 170 | + |
| 171 | + return table[type_id]; |
| 172 | + } |
| 173 | + |
| 174 | + FnT get(int _typenum) const |
| 175 | + { |
| 176 | + auto fn = get_unsafe(_typenum); |
| 177 | + |
| 178 | + if (fn == nullptr) { |
| 179 | + auto array_types = dpctl_td_ns::usm_ndarray_types(); |
| 180 | + const int _type_id = array_types.typenum_to_lookup_id(_typenum); |
| 181 | + |
| 182 | + py::dtype _dtype = dtype_from_typenum(_type_id); |
| 183 | + auto _type_pos = std::find(supported_types.begin(), |
| 184 | + supported_types.end(), _dtype); |
| 185 | + if (_type_pos == supported_types.end()) { |
| 186 | + py::str types = py::str(py::cast(supported_types)); |
| 187 | + py::str dtype = py::str(_dtype); |
| 188 | + |
| 189 | + py::str err_msg = |
| 190 | + py::str("'" + name + "' has unsupported type '") + dtype + |
| 191 | + py::str("'." |
| 192 | + " Supported types are: ") + |
| 193 | + types; |
| 194 | + |
| 195 | + throw py::value_error(static_cast<std::string>(err_msg)); |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + return fn; |
| 200 | + } |
| 201 | + |
| 202 | + const SupportedDTypeList &get_all_supported_types() const |
| 203 | + { |
| 204 | + return supported_types; |
| 205 | + } |
| 206 | + |
| 207 | +private: |
| 208 | + void populate_supported_types() |
| 209 | + { |
| 210 | + for (int i = 0; i < dpctl_td_ns::num_types; ++i) { |
| 211 | + if (table[i] != nullptr) { |
| 212 | + supported_types.emplace_back(dtype_from_typenum(i)); |
| 213 | + } |
| 214 | + } |
| 215 | + } |
| 216 | + |
| 217 | + std::string name; |
| 218 | + SupportedDTypeList supported_types; |
| 219 | + Table<FnT> table; |
| 220 | +}; |
| 221 | + |
127 | 222 | template <typename FnT>
|
128 | 223 | class DispatchTable2
|
129 | 224 | {
|
@@ -288,5 +383,4 @@ class DispatchTable2
|
288 | 383 | Table2<FnT> table;
|
289 | 384 | };
|
290 | 385 |
|
291 |
| -} // namespace common |
292 |
| -} // namespace statistics |
| 386 | +} // namespace statistics::common |
0 commit comments