Skip to content

Commit b809fc0

Browse files
Add validation in c++
1 parent 07c3a4a commit b809fc0

File tree

4 files changed

+75
-1
lines changed

4 files changed

+75
-1
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ set(_module_src
3131
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3333
${CMAKE_CURRENT_SOURCE_DIR}/kth_element1d.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp
3435
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
3536
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
3637
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp

dpnp/backend/extensions/statistics/kth_element1d.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ KthElement1d::RetT KthElement1d::call(const dpctl::tensor::usm_ndarray &a,
477477
const size_t k,
478478
const std::vector<sycl::event> &depends)
479479
{
480-
// validate(a, partitioned, k);
480+
validate(a, partitioned, k);
481481

482482
const int a_typenum = a.get_typenum();
483483
auto kth_elem_func = dispatch_table.get(a_typenum);
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024-2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <string>
27+
#include <vector>
28+
29+
#include "dpctl4pybind11.hpp"
30+
#include "utils/type_dispatch.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
#include "sliding_window1d.hpp"
34+
#include "ext/common.hpp"
35+
#include "ext/validation_utils.hpp"
36+
37+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
38+
using namespace ext::common;
39+
using namespace ext::validation;
40+
41+
using dpctl::tensor::usm_ndarray;
42+
using dpctl_td_ns::typenum_t;
43+
44+
namespace statistics::partitioning
45+
{
46+
47+
void validate(const usm_ndarray &a,
48+
const usm_ndarray &partitioned,
49+
const size_t k)
50+
{
51+
array_names names = {
52+
{&a, "a"},
53+
{&partitioned, "partitioned"}
54+
};
55+
56+
common_checks({&a}, {&partitioned}, names);
57+
check_same_size(&a, &partitioned, names);
58+
check_num_dims(&a, 1, names);
59+
check_num_dims(&partitioned, 1, names);
60+
check_same_dtype(&a, &partitioned, names);
61+
62+
if (k > a.get_size() - 2) {
63+
throw py::value_error("'k' must be from 0 to a.size() - 2, "
64+
"but got k = " + std::to_string(k) +
65+
" and a.size() = " + std::to_string(a.get_size()));
66+
}
67+
}
68+
69+
} // namespace statistics::partitioning

dpnp/backend/extensions/statistics/partitioning.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,8 @@ sycl::event run_partition_one_pivot(sycl::queue &exec_q,
224224
deps, group_size);
225225
}
226226
}
227+
228+
void validate(const usm_ndarray &a,
229+
const usm_ndarray &partitioned,
230+
const size_t k);
227231
} // namespace statistics::partitioning

0 commit comments

Comments
 (0)