Skip to content

Commit 741b933

Browse files
committed
Add missing headers in nan_to_num.cpp
1 parent 13ee6a1 commit 741b933

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26+
#include <algorithm>
27+
#include <complex>
2628
#include <stdexcept>
29+
#include <string>
30+
#include <tuple>
31+
#include <type_traits>
32+
#include <utility>
33+
#include <vector>
2734

2835
#include <sycl/sycl.hpp>
2936

@@ -145,12 +152,8 @@ std::pair<sycl::event, sycl::event>
145152
const py::ssize_t *src_shape = src.get_shape_raw();
146153
const py::ssize_t *dst_shape = dst.get_shape_raw();
147154

148-
bool shapes_equal(true);
149-
size_t nelems(1);
150-
for (int i = 0; i < src_nd; ++i) {
151-
nelems *= static_cast<size_t>(src_shape[i]);
152-
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
153-
}
155+
size_t nelems = src.get_size();
156+
bool shapes_equal = std::equal(src_shape, src_shape + src_nd, dst_shape);
154157
if (!shapes_equal) {
155158
throw py::value_error("Array shapes are not the same.");
156159
}

0 commit comments

Comments
 (0)