Skip to content

Commit 69d420e

Browse files
authored
Merge pull request #1214 from vlad-perevezentsev/not_equal_impl
Implementation of dpctl.tensor.not_equal function
2 parents 6bb09e7 + cee2e0a commit 69d420e

File tree

6 files changed

+562
-3
lines changed

6 files changed

+562
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
log,
108108
log1p,
109109
multiply,
110+
not_equal,
110111
proj,
111112
real,
112113
sin,
@@ -210,5 +211,6 @@
210211
"multiply",
211212
"subtract",
212213
"equal",
214+
"not_equal",
213215
"sum",
214216
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,33 @@
480480
# FIXME: implement U25
481481

482482
# B20: ==== NOT_EQUAL (x1, x2)
483-
# FIXME: implement B20
483+
# Docstring attached to the `not_equal` element-wise function object below.
_not_equal_docstring_ = """
not_equal(x1, x2, out=None, order='K')

Calculates inequality test results for each element `x1_i` of the
input array `x1` with the respective element `x2_i` of the input array `x2`.

Args:
    x1 (usm_ndarray):
        First input array, expected to have numeric data type.
    x2 (usm_ndarray):
        Second input array, also expected to have numeric data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate.
        Array must have the correct shape and the expected data type.
    order ("C","F","A","K", optional):
        Memory layout of the newly allocated output array, if parameter
        `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the result of element-wise inequality comparison.
        The returned array has a data type of `bool`.
"""

# Binary element-wise function object wired to the `_not_equal` bindings
# exposed by the `ti` extension module.
not_equal = BinaryElementwiseFunc(
    "not_equal", ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_
)
484510

485511
# U26: ==== POSITIVE (x)
486512
# FIXME: implement U26
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
//=== not_equal.hpp - Binary function NOT_EQUAL ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of inequality of
23+
/// tensor elements.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace not_equal
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct NotEqualFunctor
53+
{
54+
static_assert(std::is_same_v<resT, bool>);
55+
56+
using supports_sg_loadstore = std::negation<
57+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58+
using supports_vec = std::conjunction<
59+
std::is_same<argT1, argT2>,
60+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
61+
tu_ns::is_complex<argT2>>>>;
62+
63+
resT operator()(const argT1 &in1, const argT2 &in2)
64+
{
65+
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
66+
std::is_same_v<argT2, float>)
67+
{
68+
return (std::real(in1) != in2 || std::imag(in1) != 0.0f);
69+
}
70+
else if constexpr (std::is_same_v<argT1, float> &&
71+
std::is_same_v<argT2, std::complex<float>>)
72+
{
73+
return (in1 != std::real(in2) || std::imag(in2) != 0.0f);
74+
}
75+
else {
76+
return (in1 != in2);
77+
}
78+
}
79+
80+
template <int vec_sz>
81+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
82+
const sycl::vec<argT2, vec_sz> &in2)
83+
{
84+
auto tmp = (in1 != in2);
85+
if constexpr (std::is_same_v<resT,
86+
typename decltype(tmp)::element_type>) {
87+
return tmp;
88+
}
89+
else {
90+
using dpctl::tensor::type_utils::vec_cast;
91+
92+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
93+
tmp);
94+
}
95+
}
96+
};
97+
98+
template <typename argT1,
99+
typename argT2,
100+
typename resT,
101+
unsigned int vec_sz = 4,
102+
unsigned int n_vecs = 2>
103+
using NotEqualContigFunctor =
104+
elementwise_common::BinaryContigFunctor<argT1,
105+
argT2,
106+
resT,
107+
NotEqualFunctor<argT1, argT2, resT>,
108+
vec_sz,
109+
n_vecs>;
110+
111+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
112+
using NotEqualStridedFunctor = elementwise_common::BinaryStridedFunctor<
113+
argT1,
114+
argT2,
115+
resT,
116+
IndexerT,
117+
NotEqualFunctor<argT1, argT2, resT>>;
118+
119+
template <typename T1, typename T2> struct NotEqualOutputType
120+
{
121+
using value_type = typename std::disjunction< // disjunction is C++17
122+
// feature, supported by DPC++
123+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
124+
td_ns::
125+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
126+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
127+
td_ns::BinaryTypeMapResultEntry<T1,
128+
std::uint16_t,
129+
T2,
130+
std::uint16_t,
131+
bool>,
132+
td_ns::
133+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
134+
td_ns::BinaryTypeMapResultEntry<T1,
135+
std::uint32_t,
136+
T2,
137+
std::uint32_t,
138+
bool>,
139+
td_ns::
140+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
141+
td_ns::BinaryTypeMapResultEntry<T1,
142+
std::uint64_t,
143+
T2,
144+
std::uint64_t,
145+
bool>,
146+
td_ns::
147+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
148+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
149+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
150+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
151+
td_ns::BinaryTypeMapResultEntry<T1,
152+
std::complex<float>,
153+
T2,
154+
std::complex<float>,
155+
bool>,
156+
td_ns::BinaryTypeMapResultEntry<T1,
157+
std::complex<double>,
158+
T2,
159+
std::complex<double>,
160+
bool>,
161+
td_ns::
162+
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
163+
td_ns::
164+
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
165+
td_ns::DefaultResultEntry<void>>::result_type;
166+
};
167+
168+
template <typename argT1,
169+
typename argT2,
170+
typename resT,
171+
unsigned int vec_sz,
172+
unsigned int n_vecs>
173+
class not_equal_contig_kernel;
174+
175+
template <typename argTy1, typename argTy2>
176+
sycl::event not_equal_contig_impl(sycl::queue exec_q,
177+
size_t nelems,
178+
const char *arg1_p,
179+
py::ssize_t arg1_offset,
180+
const char *arg2_p,
181+
py::ssize_t arg2_offset,
182+
char *res_p,
183+
py::ssize_t res_offset,
184+
const std::vector<sycl::event> &depends = {})
185+
{
186+
return elementwise_common::binary_contig_impl<
187+
argTy1, argTy2, NotEqualOutputType, NotEqualContigFunctor,
188+
not_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
189+
arg2_offset, res_p, res_offset, depends);
190+
}
191+
192+
template <typename fnT, typename T1, typename T2> struct NotEqualContigFactory
193+
{
194+
fnT get()
195+
{
196+
if constexpr (std::is_same_v<
197+
typename NotEqualOutputType<T1, T2>::value_type,
198+
void>)
199+
{
200+
fnT fn = nullptr;
201+
return fn;
202+
}
203+
else {
204+
fnT fn = not_equal_contig_impl<T1, T2>;
205+
return fn;
206+
}
207+
}
208+
};
209+
210+
template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
211+
{
212+
/*! @brief get typeid for output type of operator()!=(x, y), always bool */
213+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
214+
{
215+
using rT = typename NotEqualOutputType<T1, T2>::value_type;
216+
return td_ns::GetTypeid<rT>{}.get();
217+
}
218+
};
219+
220+
template <typename T1, typename T2, typename resT, typename IndexerT>
221+
class not_equal_strided_strided_kernel;
222+
223+
template <typename argTy1, typename argTy2>
224+
sycl::event
225+
not_equal_strided_impl(sycl::queue exec_q,
226+
size_t nelems,
227+
int nd,
228+
const py::ssize_t *shape_and_strides,
229+
const char *arg1_p,
230+
py::ssize_t arg1_offset,
231+
const char *arg2_p,
232+
py::ssize_t arg2_offset,
233+
char *res_p,
234+
py::ssize_t res_offset,
235+
const std::vector<sycl::event> &depends,
236+
const std::vector<sycl::event> &additional_depends)
237+
{
238+
return elementwise_common::binary_strided_impl<
239+
argTy1, argTy2, NotEqualOutputType, NotEqualStridedFunctor,
240+
not_equal_strided_strided_kernel>(
241+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
242+
arg2_offset, res_p, res_offset, depends, additional_depends);
243+
}
244+
245+
template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory
246+
{
247+
fnT get()
248+
{
249+
if constexpr (std::is_same_v<
250+
typename NotEqualOutputType<T1, T2>::value_type,
251+
void>)
252+
{
253+
fnT fn = nullptr;
254+
return fn;
255+
}
256+
else {
257+
fnT fn = not_equal_strided_impl<T1, T2>;
258+
return fn;
259+
}
260+
}
261+
};
262+
263+
} // namespace not_equal
264+
} // namespace kernels
265+
} // namespace tensor
266+
} // namespace dpctl

0 commit comments

Comments
 (0)