Skip to content

Commit 17ee7a4

Browse files
Add implementation of dpctl.tensor.not_equal
1 parent c822d41 commit 17ee7a4

File tree

4 files changed

+381
-3
lines changed

4 files changed

+381
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
isfinite,
101101
isinf,
102102
isnan,
103+
not_equal,
103104
sqrt,
104105
)
105106

@@ -187,4 +188,5 @@
187188
"sqrt",
188189
"divide",
189190
"equal",
191+
"not_equal",
190192
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,27 @@
237237
# FIXME: implement U25
238238

239239
# B20: ==== NOT_EQUAL (x1, x2)
240-
# FIXME: implement B20
240+
_not_equal_docstring_ = """
241+
not_equal(x1, x2, order='K')
242+
243+
Calculates inequality test results for each element `x1_i` of the
244+
input array `x1` with the respective element `x2_i` of the input array `x2`.
245+
246+
Args:
247+
x1 (usm_ndarray):
248+
First input array, expected to have numeric data type.
249+
x2 (usm_ndarray):
250+
Second input array, also expected to have numeric data type.
251+
Returns:
252+
usm_narray:
253+
an array containing the result of element-wise inequality comparison.
254+
The data type of the returned array is determined by the
255+
Type Promotion Rules.
256+
"""
257+
258+
not_equal = BinaryElementwiseFunc(
259+
"not_equal", ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_
260+
)
241261

242262
# U26: ==== POSITIVE (x)
243263
# FIXME: implement U26
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
//=== not_equal.hpp - Binary function NOT_EQUAL ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of inequality of
23+
/// tensor elements.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace not_equal
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct NotEqualFunctor
53+
{
54+
static_assert(std::is_same_v<resT, bool>);
55+
56+
using supports_sg_loadstore = std::negation<
57+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58+
using supports_vec = std::conjunction<
59+
std::is_same<argT1, argT2>,
60+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
61+
tu_ns::is_complex<argT2>>>>;
62+
63+
resT operator()(const argT1 &in1, const argT2 &in2)
64+
{
65+
return (in1 != in2);
66+
}
67+
68+
template <int vec_sz>
69+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
70+
const sycl::vec<argT2, vec_sz> &in2)
71+
{
72+
auto tmp = (in1 != in2);
73+
if constexpr (std::is_same_v<resT,
74+
typename decltype(tmp)::element_type>) {
75+
return tmp;
76+
}
77+
else {
78+
using dpctl::tensor::type_utils::vec_cast;
79+
80+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
81+
tmp);
82+
}
83+
}
84+
};
85+
86+
template <typename argT1,
87+
typename argT2,
88+
typename resT,
89+
unsigned int vec_sz = 4,
90+
unsigned int n_vecs = 2>
91+
using NotEqualContigFunctor =
92+
elementwise_common::BinaryContigFunctor<argT1,
93+
argT2,
94+
resT,
95+
NotEqualFunctor<argT1, argT2, resT>,
96+
vec_sz,
97+
n_vecs>;
98+
99+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
100+
using NotEqualStridedFunctor =
101+
elementwise_common::BinaryStridedFunctor<argT1,
102+
argT2,
103+
resT,
104+
IndexerT,
105+
NotEqualFunctor<argT1, argT2, resT>>;
106+
107+
template <typename T1, typename T2> struct NotEqualOutputType
108+
{
109+
using value_type = typename std::disjunction< // disjunction is C++17
110+
// feature, supported by DPC++
111+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
112+
td_ns::
113+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
114+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
115+
td_ns::BinaryTypeMapResultEntry<T1,
116+
std::uint16_t,
117+
T2,
118+
std::uint16_t,
119+
bool>,
120+
td_ns::
121+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
122+
td_ns::BinaryTypeMapResultEntry<T1,
123+
std::uint32_t,
124+
T2,
125+
std::uint32_t,
126+
bool>,
127+
td_ns::
128+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
129+
td_ns::BinaryTypeMapResultEntry<T1,
130+
std::uint64_t,
131+
T2,
132+
std::uint64_t,
133+
bool>,
134+
td_ns::
135+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
136+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
137+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
138+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
139+
td_ns::BinaryTypeMapResultEntry<T1,
140+
std::complex<float>,
141+
T2,
142+
std::complex<float>,
143+
bool>,
144+
td_ns::BinaryTypeMapResultEntry<T1,
145+
std::complex<double>,
146+
T2,
147+
std::complex<double>,
148+
bool>,
149+
td_ns::DefaultResultEntry<void>>::result_type;
150+
};
151+
152+
template <typename argT1,
153+
typename argT2,
154+
typename resT,
155+
unsigned int vec_sz,
156+
unsigned int n_vecs>
157+
class not_equal_contig_kernel;
158+
159+
template <typename argTy1, typename argTy2>
160+
sycl::event not_equal_contig_impl(sycl::queue exec_q,
161+
size_t nelems,
162+
const char *arg1_p,
163+
py::ssize_t arg1_offset,
164+
const char *arg2_p,
165+
py::ssize_t arg2_offset,
166+
char *res_p,
167+
py::ssize_t res_offset,
168+
const std::vector<sycl::event> &depends = {})
169+
{
170+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
171+
cgh.depends_on(depends);
172+
173+
size_t lws = 64;
174+
constexpr unsigned int vec_sz = 4;
175+
constexpr unsigned int n_vecs = 2;
176+
const size_t n_groups =
177+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
178+
const auto gws_range = sycl::range<1>(n_groups * lws);
179+
const auto lws_range = sycl::range<1>(lws);
180+
181+
using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
182+
183+
const argTy1 *arg1_tp =
184+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
185+
const argTy2 *arg2_tp =
186+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
187+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
188+
189+
cgh.parallel_for<
190+
not_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
191+
sycl::nd_range<1>(gws_range, lws_range),
192+
NotEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
193+
arg1_tp, arg2_tp, res_tp, nelems));
194+
});
195+
return comp_ev;
196+
}
197+
198+
template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
199+
{
200+
fnT get()
201+
{
202+
if constexpr (std::is_same_v<
203+
typename NotEqualOutputType<T1, T2>::value_type, void>)
204+
{
205+
fnT fn = nullptr;
206+
return fn;
207+
}
208+
else {
209+
fnT fn = not_equal_contig_impl<T1, T2>;
210+
return fn;
211+
}
212+
}
213+
};
214+
215+
template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
216+
{
217+
/*! @brief get typeid for output type of operator()==(x, y), always bool */
218+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
219+
{
220+
using rT = typename NotEqualOutputType<T1, T2>::value_type;
221+
return td_ns::GetTypeid<rT>{}.get();
222+
}
223+
};
224+
225+
template <typename T1, typename T2, typename resT, typename IndexerT>
226+
class not_equal_strided_strided_kernel;
227+
228+
template <typename argTy1, typename argTy2>
229+
sycl::event
230+
not_equal_strided_impl(sycl::queue exec_q,
231+
size_t nelems,
232+
int nd,
233+
const py::ssize_t *shape_and_strides,
234+
const char *arg1_p,
235+
py::ssize_t arg1_offset,
236+
const char *arg2_p,
237+
py::ssize_t arg2_offset,
238+
char *res_p,
239+
py::ssize_t res_offset,
240+
const std::vector<sycl::event> &depends,
241+
const std::vector<sycl::event> &additional_depends)
242+
{
243+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
244+
cgh.depends_on(depends);
245+
cgh.depends_on(additional_depends);
246+
247+
using resTy = typename NotEqualOutputType<argTy1, argTy2>::value_type;
248+
249+
using IndexerT =
250+
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
251+
252+
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
253+
shape_and_strides};
254+
255+
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
256+
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
257+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
258+
259+
cgh.parallel_for<
260+
not_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
261+
{nelems}, NotEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
262+
arg1_tp, arg2_tp, res_tp, indexer));
263+
});
264+
return comp_ev;
265+
}
266+
267+
template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory
268+
{
269+
fnT get()
270+
{
271+
if constexpr (std::is_same_v<
272+
typename NotEqualOutputType<T1, T2>::value_type, void>)
273+
{
274+
fnT fn = nullptr;
275+
return fn;
276+
}
277+
else {
278+
fnT fn = not_equal_strided_impl<T1, T2>;
279+
return fn;
280+
}
281+
}
282+
};
283+
284+
} // namespace not_equal
285+
} // namespace kernels
286+
} // namespace tensor
287+
} // namespace dpctl

0 commit comments

Comments
 (0)