Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

namespace type_utils = dpctl::tensor::type_utils;

namespace statistics::common
namespace ext::common
{

template <typename N, typename D>
Expand Down Expand Up @@ -185,4 +185,6 @@ sycl::nd_range<1>
// headers of dpctl.
pybind11::dtype dtype_from_typenum(int dst_typenum);

} // namespace statistics::common
} // namespace ext::common

#include "ext/details/common_internal.hpp"
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include "common.hpp"
#include "ext/common.hpp"
#include "utils/type_dispatch.hpp"
#include <pybind11/pybind11.h>

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics::common
namespace ext::common
{
size_t get_max_local_size(const sycl::device &device)
inline size_t get_max_local_size(const sycl::device &device)
{
constexpr const int default_max_cpu_local_size = 256;
constexpr const int default_max_gpu_local_size = 0;
Expand All @@ -40,9 +40,9 @@ size_t get_max_local_size(const sycl::device &device)
default_max_gpu_local_size);
}

size_t get_max_local_size(const sycl::device &device,
int cpu_local_size_limit,
int gpu_local_size_limit)
inline size_t get_max_local_size(const sycl::device &device,
int cpu_local_size_limit,
int gpu_local_size_limit)
{
int max_work_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
Expand All @@ -56,30 +56,31 @@ size_t get_max_local_size(const sycl::device &device,
return max_work_group_size;
}

sycl::nd_range<1>
inline sycl::nd_range<1>
make_ndrange(size_t global_size, size_t local_range, size_t work_per_item)
{
return make_ndrange(sycl::range<1>(global_size),
sycl::range<1>(local_range),
sycl::range<1>(work_per_item));
}

size_t get_local_mem_size_in_bytes(const sycl::device &device)
inline size_t get_local_mem_size_in_bytes(const sycl::device &device)
{
// Reserving 1kb for runtime needs
constexpr const size_t reserve = 1024;

return get_local_mem_size_in_bytes(device, reserve);
}

size_t get_local_mem_size_in_bytes(const sycl::device &device, size_t reserve)
inline size_t get_local_mem_size_in_bytes(const sycl::device &device,
size_t reserve)
{
size_t local_mem_size =
device.get_info<sycl::info::device::local_mem_size>();
return local_mem_size - reserve;
}

pybind11::dtype dtype_from_typenum(int dst_typenum)
inline pybind11::dtype dtype_from_typenum(int dst_typenum)
{
dpctl_td_ns::typenum_t dst_typenum_t =
static_cast<dpctl_td_ns::typenum_t>(dst_typenum);
Expand Down Expand Up @@ -117,4 +118,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
}
}

} // namespace statistics::common
} // namespace ext::common
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,13 @@
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include "validation_utils.hpp"
#include "ext/validation_utils.hpp"
#include "utils/memory_overlap.hpp"

using statistics::validation::array_names;
using statistics::validation::array_ptr;

namespace
namespace ext::validation
{

sycl::queue get_queue(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs)
inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs)
{
auto it = std::find_if(inputs.cbegin(), inputs.cend(),
[](const array_ptr &arr) { return arr != nullptr; });
Expand All @@ -51,11 +47,8 @@ sycl::queue get_queue(const std::vector<array_ptr> &inputs,

throw py::value_error("No input or output arrays found");
}
} // namespace

namespace statistics::validation
{
std::string name_of(const array_ptr &arr, const array_names &names)
inline std::string name_of(const array_ptr &arr, const array_names &names)
{
auto name_it = names.find(arr);
assert(name_it != names.end());
Expand All @@ -66,8 +59,8 @@ std::string name_of(const array_ptr &arr, const array_names &names)
return "'unknown'";
}

void check_writable(const std::vector<array_ptr> &arrays,
const array_names &names)
inline void check_writable(const std::vector<array_ptr> &arrays,
const array_names &names)
{
for (const auto &arr : arrays) {
if (arr != nullptr && !arr->is_writable()) {
Expand All @@ -77,8 +70,8 @@ void check_writable(const std::vector<array_ptr> &arrays,
}
}

void check_c_contig(const std::vector<array_ptr> &arrays,
const array_names &names)
inline void check_c_contig(const std::vector<array_ptr> &arrays,
const array_names &names)
{
for (const auto &arr : arrays) {
if (arr != nullptr && !arr->is_c_contiguous()) {
Expand All @@ -88,9 +81,9 @@ void check_c_contig(const std::vector<array_ptr> &arrays,
}
}

void check_queue(const std::vector<array_ptr> &arrays,
const array_names &names,
const sycl::queue &exec_q)
inline void check_queue(const std::vector<array_ptr> &arrays,
const array_names &names,
const sycl::queue &exec_q)
{
auto unequal_queue =
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
Expand All @@ -104,9 +97,9 @@ void check_queue(const std::vector<array_ptr> &arrays,
}
}

void check_no_overlap(const array_ptr &input,
const array_ptr &output,
const array_names &names)
inline void check_no_overlap(const array_ptr &input,
const array_ptr &output,
const array_names &names)
{
if (input == nullptr || output == nullptr) {
return;
Expand All @@ -121,9 +114,9 @@ void check_no_overlap(const array_ptr &input,
}
}

void check_no_overlap(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names)
inline void check_no_overlap(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names)
{
for (const auto &input : inputs) {
for (const auto &output : outputs) {
Expand All @@ -132,9 +125,9 @@ void check_no_overlap(const std::vector<array_ptr> &inputs,
}
}

void check_num_dims(const array_ptr &arr,
const size_t ndim,
const array_names &names)
inline void check_num_dims(const array_ptr &arr,
const size_t ndim,
const array_names &names)
{
size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
if (arr != nullptr && arr_n_dim != ndim) {
Expand All @@ -144,9 +137,9 @@ void check_num_dims(const array_ptr &arr,
}
}

void check_max_dims(const array_ptr &arr,
const size_t max_ndim,
const array_names &names)
inline void check_max_dims(const array_ptr &arr,
const size_t max_ndim,
const array_names &names)
{
size_t arr_n_dim = arr != nullptr ? arr->get_ndim() : 0;
if (arr != nullptr && arr_n_dim > max_ndim) {
Expand All @@ -157,9 +150,9 @@ void check_max_dims(const array_ptr &arr,
}
}

void check_size_at_least(const array_ptr &arr,
const size_t size,
const array_names &names)
inline void check_size_at_least(const array_ptr &arr,
const size_t size,
const array_names &names)
{
size_t arr_size = arr != nullptr ? arr->get_size() : 0;
if (arr != nullptr && arr_size < size) {
Expand All @@ -170,9 +163,9 @@ void check_size_at_least(const array_ptr &arr,
}
}

void common_checks(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names)
inline void common_checks(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names)
{
check_writable(outputs, names);

Expand All @@ -187,4 +180,4 @@ void common_checks(const std::vector<array_ptr> &inputs,
check_no_overlap(inputs, outputs, names);
}

} // namespace statistics::validation
} // namespace ext::validation
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
#include <pybind11/stl.h>
#include <sycl/sycl.hpp>

#include "common.hpp"
#include "ext/common.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace py = pybind11;

namespace statistics::common
namespace ext::common
{
template <typename T, typename Rest>
struct one_of
Expand Down Expand Up @@ -383,4 +383,4 @@ class DispatchTable2
Table2<FnT> table;
};

} // namespace statistics::common
} // namespace ext::common
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

#include "dpctl4pybind11.hpp"

namespace statistics::validation
namespace ext::validation
{
using array_ptr = const dpctl::tensor::usm_ndarray *;
using array_names = std::unordered_map<array_ptr, std::string>;
Expand Down Expand Up @@ -67,4 +67,6 @@ void check_size_at_least(const array_ptr &arr,
void common_checks(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names);
} // namespace statistics::validation
} // namespace ext::validation

#include "ext/details/validation_utils_internal.hpp"
3 changes: 1 addition & 2 deletions dpnp/backend/extensions/statistics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@
set(python_module_name _statistics_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down Expand Up @@ -66,6 +64,7 @@ set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDEN

target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)

target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})
Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/statistics/bincount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
using dpctl::tensor::usm_ndarray;

using namespace statistics::histogram;
using namespace statistics::common;
using namespace ext::common;

namespace
{
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/statistics/bincount.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"
#include "ext/dispatch_table.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

Expand All @@ -46,7 +46,7 @@ struct Bincount
const size_t,
const std::vector<sycl::event> &);

common::DispatchTable2<FnT> dispatch_table;
ext::common::DispatchTable2<FnT> dispatch_table;

Bincount();

Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/statistics/histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
using dpctl::tensor::usm_ndarray;

using namespace statistics::histogram;
using namespace statistics::common;
using namespace ext::common;

namespace
{
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/statistics/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"
#include "ext/dispatch_table.hpp"

namespace statistics::histogram
{
Expand All @@ -44,7 +44,7 @@ struct Histogram
const size_t,
const std::vector<sycl::event> &);

common::DispatchTable2<FnT> dispatch_table;
ext::common::DispatchTable2<FnT> dispatch_table;

Histogram();

Expand Down
Loading
Loading