Skip to content

Commit fbbf1b3

Browse files
Implementation of dpctl.tensor.less_equal
1 parent 47466ee commit fbbf1b3

File tree

5 files changed

+679
-3
lines changed

5 files changed

+679
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
isfinite,
106106
isinf,
107107
isnan,
108+
less_equal,
108109
log,
109110
log1p,
110111
multiply,
@@ -202,6 +203,7 @@
202203
"isinf",
203204
"isnan",
204205
"isfinite",
206+
"less_equal",
205207
"log",
206208
"log1p",
207209
"proj",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,35 @@
408408
# FIXME: implement B13
409409

410410
# B14: ==== LESS_EQUAL (x1, x2)
411-
# FIXME: implement B14
411+
_less_equal_docstring_ = """
412+
less_equal(x1, x2, out=None, order='K')
413+
Computes the less-than or equal-to test results for each element `x1_i` of
414+
the input array `x1` with the respective element `x2_i` of the input array `x2`.
415+
Args:
416+
x1 (usm_ndarray):
417+
First input array, expected to have numeric data type.
418+
x2 (usm_ndarray):
419+
Second input array, also expected to have numeric data type.
420+
out ({None, usm_ndarray}, optional):
421+
Output array to populate.
422+
Array must have the correct shape and the expected data type.
423+
order ("C","F","A","K", optional):
424+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
425+
Default: "K".
426+
Returns:
427+
usm_ndarray:
428+
An array containing the result of element-wise less-than or equal-to
429+
comparison.
430+
The data type of the returned array is determined by the
431+
Type Promotion Rules.
432+
"""
433+
434+
# Public element-wise `less_equal` callable. It is built by handing
# BinaryElementwiseFunc the operation name, the `ti._less_equal_result_type`
# and `ti._less_equal` entry points, and the docstring defined above.
less_equal = BinaryElementwiseFunc(
435+
"less_equal",
436+
ti._less_equal_result_type,
437+
ti._less_equal,
438+
_less_equal_docstring_,
439+
)
412440

413441
# U20: ==== LOG (x)
414442
_log_docstring = """
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
//=== less_equal.hpp - Binary function LESS_EQUAL ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of comparison of
24+
/// tensor elements.
25+
//===---------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include <CL/sycl.hpp>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "utils/offset_utils.hpp"
34+
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
36+
37+
#include "kernels/elementwise_functions/common.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace less_equal
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace tu_ns = dpctl::tensor::type_utils;
52+
53+
/// Functor evaluating `in1 <= in2` for a pair of elements (scalar overload)
/// or a pair of `sycl::vec` operands (vector overload).
///
/// \tparam argT1 type of the left-hand operand
/// \tparam argT2 type of the right-hand operand
/// \tparam resT  result type; must be `bool` (enforced by the static_assert)
template <typename argT1, typename argT2, typename resT> struct LessEqualFunctor
54+
{
55+
static_assert(std::is_same_v<resT, bool>);
56+
57+
// True unless either argument type is complex.
// NOTE(review): presumably consulted by BinaryContigFunctor to pick its
// load/store path -- confirm in kernels/elementwise_functions/common.hpp.
using supports_sg_loadstore = std::negation<
58+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
// True only when both argument types are identical and neither is complex,
// i.e. when the sycl::vec overload of operator() below is applicable.
using supports_vec = std::conjunction<
60+
std::is_same<argT1, argT2>,
61+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
62+
tu_ns::is_complex<argT2>>>>;
63+
64+
/// Scalar comparison. Complex operands are ordered lexicographically:
/// real parts are compared first and imaginary parts are compared only
/// when the real parts are equal. Returns the value of `in1 <= in2`.
resT operator()(const argT1 &in1, const argT2 &in2)
65+
{
66+
// complex<float> vs float: the real operand is compared as if its
// imaginary part were zero.
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
67+
std::is_same_v<argT2, float>)
68+
{
69+
float real1 = std::real(in1);
70+
return (real1 == in2) ? (std::imag(in1) <= 0.0f) : real1 <= in2;
71+
}
72+
// float vs complex<float>: mirror image of the branch above.
else if constexpr (std::is_same_v<argT1, float> &&
73+
std::is_same_v<argT2, std::complex<float>>)
74+
{
75+
float real2 = std::real(in2);
76+
return (in1 == real2) ? (0.0f <= std::imag(in2)) : in1 <= real2;
77+
}
78+
// Any remaining complex case requires both operands to have the same
// complex type (checked at compile time).
else if constexpr (tu_ns::is_complex<argT1>::value ||
79+
tu_ns::is_complex<argT2>::value)
80+
{
81+
static_assert(std::is_same_v<argT1, argT2>);
82+
using realT = typename argT1::value_type;
83+
realT real1 = std::real(in1);
84+
realT real2 = std::real(in2);
85+
86+
return (real1 == real2) ? (std::imag(in1) <= std::imag(in2))
87+
: real1 <= real2;
88+
}
89+
// Non-complex operands: use the built-in comparison.
else {
90+
return (in1 <= in2);
91+
}
92+
}
93+
94+
/// Vector comparison: applies `<=` to two `sycl::vec` operands of length
/// `vec_sz` and returns the result as `sycl::vec<resT, vec_sz>`.
template <int vec_sz>
95+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
96+
const sycl::vec<argT2, vec_sz> &in2)
97+
{
98+
99+
auto tmp = (in1 <= in2);
100+
101+
// The element type of the comparison result may differ from resT; it is
// returned directly when it matches and converted with vec_cast otherwise.
if constexpr (std::is_same_v<resT,
102+
typename decltype(tmp)::element_type>) {
103+
return tmp;
104+
}
105+
else {
106+
using dpctl::tensor::type_utils::vec_cast;
107+
108+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
109+
tmp);
110+
}
111+
}
112+
};
113+
114+
/// Alias that plugs LessEqualFunctor into
/// elementwise_common::BinaryContigFunctor. `vec_sz` defaults to 4 and
/// `n_vecs` defaults to 2; both are forwarded unchanged to the wrapped
/// template.
template <typename argT1,
115+
typename argT2,
116+
typename resT,
117+
unsigned int vec_sz = 4,
118+
unsigned int n_vecs = 2>
119+
using LessEqualContigFunctor = elementwise_common::BinaryContigFunctor<
120+
argT1,
121+
argT2,
122+
resT,
123+
LessEqualFunctor<argT1, argT2, resT>,
124+
vec_sz,
125+
n_vecs>;
126+
127+
/// Alias that plugs LessEqualFunctor into
/// elementwise_common::BinaryStridedFunctor, parameterized by the indexer
/// type `IndexerT` used to locate elements.
template <typename argT1, typename argT2, typename resT, typename IndexerT>
128+
using LessEqualStridedFunctor = elementwise_common::BinaryStridedFunctor<
129+
argT1,
130+
argT2,
131+
resT,
132+
IndexerT,
133+
LessEqualFunctor<argT1, argT2, resT>>;
134+
135+
/// Compile-time map from an argument type pair (T1, T2) to the result type
/// of less_equal.
///
/// `value_type` is the `result_type` of the first entry selected by
/// std::disjunction. Every listed entry maps to `bool`: matching pairs of
/// bool, 8/16/32/64-bit signed and unsigned integers, sycl::half, float,
/// double, std::complex<float> and std::complex<double>, plus the two mixed
/// float / std::complex<float> pairs. The trailing DefaultResultEntry<void>
/// is the fallback for every other pair; the factories in this file treat a
/// `void` value_type as "unsupported".
template <typename T1, typename T2> struct LessEqualOutputType
136+
{
137+
using value_type = typename std::disjunction< // disjunction is C++17
138+
// feature, supported by DPC++
139+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
140+
td_ns::
141+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
142+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
143+
td_ns::BinaryTypeMapResultEntry<T1,
144+
std::uint16_t,
145+
T2,
146+
std::uint16_t,
147+
bool>,
148+
td_ns::
149+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
150+
td_ns::BinaryTypeMapResultEntry<T1,
151+
std::uint32_t,
152+
T2,
153+
std::uint32_t,
154+
bool>,
155+
td_ns::
156+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
157+
td_ns::BinaryTypeMapResultEntry<T1,
158+
std::uint64_t,
159+
T2,
160+
std::uint64_t,
161+
bool>,
162+
td_ns::
163+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
164+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
165+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
166+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
167+
td_ns::BinaryTypeMapResultEntry<T1,
168+
std::complex<float>,
169+
T2,
170+
std::complex<float>,
171+
bool>,
172+
td_ns::BinaryTypeMapResultEntry<T1,
173+
std::complex<double>,
174+
T2,
175+
std::complex<double>,
176+
bool>,
177+
// Mixed real/complex single-precision pairs, matching the dedicated
// branches in LessEqualFunctor's scalar operator().
td_ns::
178+
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
179+
td_ns::
180+
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
181+
td_ns::DefaultResultEntry<void>>::result_type;
182+
};
183+
184+
/// Kernel name type for the `parallel_for` submitted by
/// less_equal_contig_impl. It is only declared here and serves purely as a
/// template tag; the template arguments make the name unique per
/// instantiation.
template <typename argT1,
185+
typename argT2,
186+
typename resT,
187+
unsigned int vec_sz,
188+
unsigned int n_vecs>
189+
class less_equal_contig_kernel;
190+
191+
/// Submits the less_equal kernel for contiguous inputs and output.
///
/// \param exec_q      queue the kernel is submitted to
/// \param nelems      number of elements to process
/// \param arg1_p      untyped pointer to the first input
/// \param arg1_offset offset of the first input, counted in elements of
///                    `argTy1` (it is added after the typed cast)
/// \param arg2_p      untyped pointer to the second input
/// \param arg2_offset offset of the second input, in elements of `argTy2`
/// \param res_p       untyped pointer to the output
/// \param res_offset  offset of the output, in elements of the result type
/// \param depends     events the kernel must wait for
/// \return event associated with the submitted kernel
template <typename argTy1, typename argTy2>
192+
sycl::event less_equal_contig_impl(sycl::queue exec_q,
193+
size_t nelems,
194+
const char *arg1_p,
195+
py::ssize_t arg1_offset,
196+
const char *arg2_p,
197+
py::ssize_t arg2_offset,
198+
char *res_p,
199+
py::ssize_t res_offset,
200+
const std::vector<sycl::event> &depends = {})
201+
{
202+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
203+
cgh.depends_on(depends);
204+
205+
// Launch geometry: work-groups of `lws` items, with one work-group per
// lws * n_vecs * vec_sz elements (rounded up).
size_t lws = 64;
206+
constexpr unsigned int vec_sz = 4;
207+
constexpr unsigned int n_vecs = 2;
208+
const size_t n_groups =
209+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
210+
const auto gws_range = sycl::range<1>(n_groups * lws);
211+
const auto lws_range = sycl::range<1>(lws);
212+
213+
using resTy = typename LessEqualOutputType<argTy1, argTy2>::value_type;
214+
215+
// Cast to typed pointers first, then apply the element offsets.
const argTy1 *arg1_tp =
216+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
217+
const argTy2 *arg2_tp =
218+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
219+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
220+
221+
cgh.parallel_for<
222+
less_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
223+
sycl::nd_range<1>(gws_range, lws_range),
224+
LessEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
225+
arg1_tp, arg2_tp, res_tp, nelems));
226+
});
227+
return comp_ev;
228+
}
229+
230+
/// Factory yielding the contiguous implementation for the type pair
/// (T1, T2). `get()` returns `nullptr` when LessEqualOutputType maps the pair
/// to `void` (unsupported combination) and a pointer to
/// `less_equal_contig_impl<T1, T2>` otherwise.
template <typename fnT, typename T1, typename T2> struct LessEqualContigFactory
231+
{
232+
fnT get()
233+
{
234+
if constexpr (std::is_same_v<
235+
typename LessEqualOutputType<T1, T2>::value_type,
236+
void>)
237+
{
238+
fnT fn = nullptr;
239+
return fn;
240+
}
241+
else {
242+
fnT fn = less_equal_contig_impl<T1, T2>;
243+
return fn;
244+
}
245+
}
246+
};
247+
248+
/// Factory reporting the type id of the less_equal result for the type pair
/// (T1, T2). `get()` is only well-formed when `fnT` is `int`.
template <typename fnT, typename T1, typename T2> struct LessEqualTypeMapFactory
249+
{
250+
/*! @brief get typeid for the output type of less_equal (x <= y) as mapped
 *  by LessEqualOutputType; that type is bool for every listed type pair */
251+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
252+
{
253+
using rT = typename LessEqualOutputType<T1, T2>::value_type;
254+
return td_ns::GetTypeid<rT>{}.get();
255+
}
256+
};
257+
258+
/// Kernel name type for the `parallel_for` submitted by
/// less_equal_strided_impl. It is only declared here and serves purely as a
/// template tag.
template <typename T1, typename T2, typename resT, typename IndexerT>
259+
class less_equal_strided_strided_kernel;
260+
261+
/// Submits the less_equal kernel for strided inputs and output.
///
/// Unlike the contiguous variant, the offsets are not added to the pointers
/// here; they are handed to the ThreeOffsets_StridedIndexer together with
/// `nd` and `shape_and_strides`.
///
/// \param exec_q             queue the kernel is submitted to
/// \param nelems             number of elements; used as the 1-D launch range
/// \param nd                 number of dimensions passed to the indexer
/// \param shape_and_strides  packed shape/stride data consumed by the indexer
///                           (NOTE(review): exact layout and memory residency
///                           are defined by ThreeOffsets_StridedIndexer and
///                           the caller -- not visible in this file)
/// \param arg1_p             untyped pointer to the first input
/// \param arg1_offset        offset of the first input, given to the indexer
/// \param arg2_p             untyped pointer to the second input
/// \param arg2_offset        offset of the second input, given to the indexer
/// \param res_p              untyped pointer to the output
/// \param res_offset         offset of the output, given to the indexer
/// \param depends            events the kernel must wait for
/// \param additional_depends further events the kernel must wait for
/// \return event associated with the submitted kernel
template <typename argTy1, typename argTy2>
262+
sycl::event
263+
less_equal_strided_impl(sycl::queue exec_q,
264+
size_t nelems,
265+
int nd,
266+
const py::ssize_t *shape_and_strides,
267+
const char *arg1_p,
268+
py::ssize_t arg1_offset,
269+
const char *arg2_p,
270+
py::ssize_t arg2_offset,
271+
char *res_p,
272+
py::ssize_t res_offset,
273+
const std::vector<sycl::event> &depends,
274+
const std::vector<sycl::event> &additional_depends)
275+
{
276+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
277+
// The kernel waits on both dependency lists.
cgh.depends_on(depends);
278+
cgh.depends_on(additional_depends);
279+
280+
using resTy = typename LessEqualOutputType<argTy1, argTy2>::value_type;
281+
282+
using IndexerT =
283+
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
284+
285+
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
286+
shape_and_strides};
287+
288+
// Base pointers are left un-offset; the indexer carries the offsets.
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
289+
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
290+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
291+
292+
cgh.parallel_for<
293+
less_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
294+
{nelems}, LessEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
295+
arg1_tp, arg2_tp, res_tp, indexer));
296+
});
297+
return comp_ev;
298+
}
299+
300+
/// Factory yielding the strided implementation for the type pair (T1, T2).
/// `get()` returns `nullptr` when LessEqualOutputType maps the pair to `void`
/// (unsupported combination) and a pointer to
/// `less_equal_strided_impl<T1, T2>` otherwise.
template <typename fnT, typename T1, typename T2> struct LessEqualStridedFactory
301+
{
302+
fnT get()
303+
{
304+
if constexpr (std::is_same_v<
305+
typename LessEqualOutputType<T1, T2>::value_type,
306+
void>)
307+
{
308+
fnT fn = nullptr;
309+
return fn;
310+
}
311+
else {
312+
fnT fn = less_equal_strided_impl<T1, T2>;
313+
return fn;
314+
}
315+
}
316+
};
317+
318+
} // namespace less_equal
319+
} // namespace kernels
320+
} // namespace tensor
321+
} // namespace dpctl

0 commit comments

Comments
 (0)