Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dpnp/backend/extensions/statistics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@

set(python_module_name _statistics_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/bincount.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace histogram
namespace statistics::histogram
{
struct Bincount
{
Expand Down Expand Up @@ -62,5 +60,4 @@ struct Bincount
};

void populate_bincount(py::module_ m);
} // namespace histogram
} // namespace statistics
} // namespace statistics::histogram
8 changes: 2 additions & 6 deletions dpnp/backend/extensions/statistics/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
namespace statistics::common
{
namespace common
{

size_t get_max_local_size(const sycl::device &device)
{
constexpr const int default_max_cpu_local_size = 256;
Expand Down Expand Up @@ -120,5 +117,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
}
}

} // namespace common
} // namespace statistics
} // namespace statistics::common
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@
#include "utils/math_utils.hpp"
// clang-format on

namespace statistics
{
namespace common
namespace statistics::common
{

template <typename N, typename D>
Expand Down Expand Up @@ -200,5 +198,4 @@ sycl::nd_range<1>
// headers of dpctl.
pybind11::dtype dtype_from_typenum(int dst_typenum);

} // namespace common
} // namespace statistics
} // namespace statistics::common
106 changes: 100 additions & 6 deletions dpnp/backend/extensions/statistics/dispatch_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace py = pybind11;

namespace statistics
namespace statistics::common
{
namespace common
{

template <typename T, typename Rest>
struct one_of
{
Expand Down Expand Up @@ -97,6 +94,32 @@ using DTypePair = std::pair<DType, DType>;
using SupportedDTypeList = std::vector<DType>;
using SupportedDTypeList2 = std::vector<DTypePair>;

template <typename FnT,
typename SupportedTypes,
template <typename>
typename Func>
struct TableBuilder
{
template <typename _FnT, typename T>
struct impl
{
static constexpr bool is_defined = one_of_v<T, SupportedTypes>;

_FnT get()
{
if constexpr (is_defined) {
return Func<T>::impl;
}
else {
return nullptr;
}
}
};

using type =
dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
};

template <typename FnT,
typename SupportedTypes,
template <typename, typename>
Expand Down Expand Up @@ -124,6 +147,78 @@ struct TableBuilder2
dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
};

template <typename FnT>
class DispatchTable
{
public:
DispatchTable(std::string name) : name(name) {}

template <typename SupportedTypes, template <typename> typename Func>
void populate_dispatch_table()
{
using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
TBulder builder;

builder.populate_dispatch_vector(table);
populate_supported_types();
}

FnT get_unsafe(int _typenum) const
{
auto array_types = dpctl_td_ns::usm_ndarray_types();
const int type_id = array_types.typenum_to_lookup_id(_typenum);

return table[type_id];
}

FnT get(int _typenum) const
{
auto fn = get_unsafe(_typenum);

if (fn == nullptr) {
auto array_types = dpctl_td_ns::usm_ndarray_types();
const int _type_id = array_types.typenum_to_lookup_id(_typenum);

py::dtype _dtype = dtype_from_typenum(_type_id);
auto _type_pos = std::find(supported_types.begin(),
supported_types.end(), _dtype);
if (_type_pos == supported_types.end()) {
py::str types = py::str(py::cast(supported_types));
py::str dtype = py::str(_dtype);

py::str err_msg =
py::str("'" + name + "' has unsupported type '") + dtype +
py::str("'."
" Supported types are: ") +
types;

throw py::value_error(static_cast<std::string>(err_msg));
}
}

return fn;
}

const SupportedDTypeList &get_all_supported_types() const
{
return supported_types;
}

private:
void populate_supported_types()
{
for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
if (table[i] != nullptr) {
supported_types.emplace_back(dtype_from_typenum(i));
}
}
}

std::string name;
SupportedDTypeList supported_types;
Table<FnT> table;
};

template <typename FnT>
class DispatchTable2
{
Expand Down Expand Up @@ -288,5 +383,4 @@ class DispatchTable2
Table2<FnT> table;
};

} // namespace common
} // namespace statistics
} // namespace statistics::common
4 changes: 1 addition & 3 deletions dpnp/backend/extensions/statistics/histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
#include <algorithm>
#include <complex>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <tuple>
#include <vector>

#include <pybind11/pybind11.h>
Expand Down
9 changes: 3 additions & 6 deletions dpnp/backend/extensions/statistics/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@

#include "dispatch_table.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
// namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace histogram
namespace statistics::histogram
{
struct Histogram
{
Expand All @@ -59,5 +57,4 @@ struct Histogram
};

void populate_histogram(py::module_ m);
} // namespace histogram
} // namespace statistics
} // namespace statistics::histogram
Loading
Loading