Skip to content

Commit 05e1bb6

Browse files
authored
Rework implementation of dpnp.fmax and dpnp.fmin functions (#1905)
* Implement dpnp.fmax and dpnp.fmin functions * Updated existing tests and added new ones * Removed unused code from cython backend * Removed a reference to original descriptor
1 parent 740b08b commit 05e1bb6

30 files changed

+1154
-501
lines changed

doc/reference/ufunc.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,12 @@ Comparison functions
105105
dpnp.less_equal
106106
dpnp.not_equal
107107
dpnp.equal
108+
108109
dpnp.logical_and
109110
dpnp.logical_or
110111
dpnp.logical_xor
111112
dpnp.logical_not
113+
112114
dpnp.maximum
113115
dpnp.minimum
114116
dpnp.fmax

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
set(_elementwise_sources
2727
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/common.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fabs.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmax.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmin.cpp
2931
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmod.cpp
3032
)
3133

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <pybind11/pybind11.h>
2727

2828
#include "fabs.hpp"
29+
#include "fmax.hpp"
30+
#include "fmin.hpp"
2931
#include "fmod.hpp"
3032

3133
namespace py = pybind11;
@@ -38,6 +40,8 @@ namespace dpnp::extensions::ufunc
3840
void init_elementwise_functions(py::module_ m)
3941
{
4042
init_fabs(m);
43+
init_fmax(m);
44+
init_fmin(m);
4145
init_fmod(m);
4246
}
4347
} // namespace dpnp::extensions::ufunc
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <sycl/sycl.hpp>
27+
28+
#include "dpctl4pybind11.hpp"
29+
30+
#include "fmax.hpp"
31+
#include "kernels/elementwise_functions/fmax.hpp"
32+
#include "populate.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "kernels/elementwise_functions/maximum.hpp"
42+
#include "utils/type_dispatch.hpp"
43+
44+
namespace py = pybind11;
45+
46+
namespace dpnp::extensions::ufunc
47+
{
48+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
49+
namespace max_ns = dpctl::tensor::kernels::maximum;
50+
namespace py_int = dpnp::extensions::py_internal;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
54+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
55+
56+
namespace impl
57+
{
58+
// Supports the same types table as for maximum function in dpctl
59+
template <typename T1, typename T2>
60+
using OutputType = max_ns::MaximumOutputType<T1, T2>;
61+
62+
using dpnp::kernels::fmax::FmaxFunctor;
63+
64+
template <typename argT1,
65+
typename argT2,
66+
typename resT,
67+
unsigned int vec_sz = 4,
68+
unsigned int n_vecs = 2,
69+
bool enable_sg_loadstore = true>
70+
using ContigFunctor =
71+
ew_cmn_ns::BinaryContigFunctor<argT1,
72+
argT2,
73+
resT,
74+
FmaxFunctor<argT1, argT2, resT>,
75+
vec_sz,
76+
n_vecs,
77+
enable_sg_loadstore>;
78+
79+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
80+
using StridedFunctor =
81+
ew_cmn_ns::BinaryStridedFunctor<argT1,
82+
argT2,
83+
resT,
84+
IndexerT,
85+
FmaxFunctor<argT1, argT2, resT>>;
86+
87+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
88+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
89+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
90+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
91+
92+
static binary_contig_impl_fn_ptr_t fmax_contig_dispatch_table[td_ns::num_types]
93+
[td_ns::num_types];
94+
static int fmax_output_typeid_table[td_ns::num_types][td_ns::num_types];
95+
static binary_strided_impl_fn_ptr_t
96+
fmax_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
97+
98+
MACRO_POPULATE_DISPATCH_TABLES(fmax);
99+
} // namespace impl
100+
101+
void init_fmax(py::module_ m)
102+
{
103+
using arrayT = dpctl::tensor::usm_ndarray;
104+
using event_vecT = std::vector<sycl::event>;
105+
{
106+
impl::populate_fmax_dispatch_tables();
107+
using impl::fmax_contig_dispatch_table;
108+
using impl::fmax_output_typeid_table;
109+
using impl::fmax_strided_dispatch_table;
110+
111+
auto fmax_pyapi = [&](const arrayT &src1, const arrayT &src2,
112+
const arrayT &dst, sycl::queue &exec_q,
113+
const event_vecT &depends = {}) {
114+
return py_int::py_binary_ufunc(
115+
src1, src2, dst, exec_q, depends, fmax_output_typeid_table,
116+
fmax_contig_dispatch_table, fmax_strided_dispatch_table,
117+
// no dedicated kernels for C-contig row with broadcasting, so null tables are passed
118+
td_ns::NullPtrTable<
119+
impl::
120+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
121+
td_ns::NullPtrTable<
122+
impl::
123+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
124+
};
125+
m.def("_fmax", fmax_pyapi, "", py::arg("src1"), py::arg("src2"),
126+
py::arg("dst"), py::arg("sycl_queue"),
127+
py::arg("depends") = py::list());
128+
129+
auto fmax_result_type_pyapi = [&](const py::dtype &dtype1,
130+
const py::dtype &dtype2) {
131+
return py_int::py_binary_ufunc_result_type(
132+
dtype1, dtype2, fmax_output_typeid_table);
133+
};
134+
m.def("_fmax_result_type", fmax_result_type_pyapi);
135+
}
136+
}
137+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_fmax(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <sycl/sycl.hpp>
27+
28+
#include "dpctl4pybind11.hpp"
29+
30+
#include "fmin.hpp"
31+
#include "kernels/elementwise_functions/fmin.hpp"
32+
#include "populate.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "kernels/elementwise_functions/minimum.hpp"
42+
#include "utils/type_dispatch.hpp"
43+
44+
namespace py = pybind11;
45+
46+
namespace dpnp::extensions::ufunc
47+
{
48+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
49+
namespace min_ns = dpctl::tensor::kernels::minimum;
50+
namespace py_int = dpnp::extensions::py_internal;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
54+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
55+
56+
namespace impl
57+
{
58+
// Supports the same types table as for minimum function in dpctl
59+
template <typename T1, typename T2>
60+
using OutputType = min_ns::MinimumOutputType<T1, T2>;
61+
62+
using dpnp::kernels::fmin::FminFunctor;
63+
64+
template <typename argT1,
65+
typename argT2,
66+
typename resT,
67+
unsigned int vec_sz = 4,
68+
unsigned int n_vecs = 2,
69+
bool enable_sg_loadstore = true>
70+
using ContigFunctor =
71+
ew_cmn_ns::BinaryContigFunctor<argT1,
72+
argT2,
73+
resT,
74+
FminFunctor<argT1, argT2, resT>,
75+
vec_sz,
76+
n_vecs,
77+
enable_sg_loadstore>;
78+
79+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
80+
using StridedFunctor =
81+
ew_cmn_ns::BinaryStridedFunctor<argT1,
82+
argT2,
83+
resT,
84+
IndexerT,
85+
FminFunctor<argT1, argT2, resT>>;
86+
87+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
88+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
89+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
90+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
91+
92+
static binary_contig_impl_fn_ptr_t fmin_contig_dispatch_table[td_ns::num_types]
93+
[td_ns::num_types];
94+
static int fmin_output_typeid_table[td_ns::num_types][td_ns::num_types];
95+
static binary_strided_impl_fn_ptr_t
96+
fmin_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
97+
98+
MACRO_POPULATE_DISPATCH_TABLES(fmin);
99+
} // namespace impl
100+
101+
void init_fmin(py::module_ m)
102+
{
103+
using arrayT = dpctl::tensor::usm_ndarray;
104+
using event_vecT = std::vector<sycl::event>;
105+
{
106+
impl::populate_fmin_dispatch_tables();
107+
using impl::fmin_contig_dispatch_table;
108+
using impl::fmin_output_typeid_table;
109+
using impl::fmin_strided_dispatch_table;
110+
111+
auto fmin_pyapi = [&](const arrayT &src1, const arrayT &src2,
112+
const arrayT &dst, sycl::queue &exec_q,
113+
const event_vecT &depends = {}) {
114+
return py_int::py_binary_ufunc(
115+
src1, src2, dst, exec_q, depends, fmin_output_typeid_table,
116+
fmin_contig_dispatch_table, fmin_strided_dispatch_table,
117+
// no dedicated kernels for C-contig row with broadcasting, so null tables are passed
118+
td_ns::NullPtrTable<
119+
impl::
120+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
121+
td_ns::NullPtrTable<
122+
impl::
123+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
124+
};
125+
m.def("_fmin", fmin_pyapi, "", py::arg("src1"), py::arg("src2"),
126+
py::arg("dst"), py::arg("sycl_queue"),
127+
py::arg("depends") = py::list());
128+
129+
auto fmin_result_type_pyapi = [&](const py::dtype &dtype1,
130+
const py::dtype &dtype2) {
131+
return py_int::py_binary_ufunc_result_type(
132+
dtype1, dtype2, fmin_output_typeid_table);
133+
};
134+
m.def("_fmin_result_type", fmin_result_type_pyapi);
135+
}
136+
}
137+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_fmin(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ set(_elementwise_sources
4343
${CMAKE_CURRENT_SOURCE_DIR}/exp2.cpp
4444
${CMAKE_CURRENT_SOURCE_DIR}/expm1.cpp
4545
${CMAKE_CURRENT_SOURCE_DIR}/floor.cpp
46+
${CMAKE_CURRENT_SOURCE_DIR}/fmax.cpp
47+
${CMAKE_CURRENT_SOURCE_DIR}/fmin.cpp
4648
${CMAKE_CURRENT_SOURCE_DIR}/fmod.cpp
4749
${CMAKE_CURRENT_SOURCE_DIR}/hypot.cpp
4850
${CMAKE_CURRENT_SOURCE_DIR}/ln.cpp

0 commit comments

Comments
 (0)