Skip to content

Commit cdce27a

Browse files
committed
Add DPNPUnaryTwoOutputsFunc class for unary element-wise functions with two output arrays
1 parent 7a9de99 commit cdce27a

File tree

2 files changed

+255
-3
lines changed

2 files changed

+255
-3
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
# *****************************************************************************
2929

3030
import dpctl.tensor as dpt
31+
import dpctl.tensor._copy_utils as dtc
3132
import dpctl.tensor._tensor_impl as dti
3233
import dpctl.tensor._type_utils as dtu
34+
import dpctl.utils as dpu
3335
import numpy
3436
from dpctl.tensor._elementwise_common import (
3537
BinaryElementwiseFunc,
@@ -39,6 +41,9 @@
3941
import dpnp
4042
import dpnp.backend.extensions.vm._vm_impl as vmi
4143
from dpnp.dpnp_array import dpnp_array
44+
from dpnp.dpnp_utils.dpnp_utils_common import (
45+
find_buf_dtype_3out,
46+
)
4247

4348
__all__ = [
4449
"DPNPI0",
@@ -50,6 +55,7 @@
5055
"DPNPRound",
5156
"DPNPSinc",
5257
"DPNPUnaryFunc",
58+
"DPNPUnaryTwoOutputsFunc",
5359
"acceptance_fn_gcd_lcm",
5460
"acceptance_fn_negative",
5561
"acceptance_fn_positive",
@@ -102,6 +108,7 @@ class DPNPUnaryFunc(UnaryElementwiseFunc):
102108
The function is invoked when the argument of the unary function
103109
requires casting, e.g. the argument of `dpctl.tensor.log` is an
104110
array with integral data type.
111+
105112
"""
106113

107114
def __init__(
@@ -197,6 +204,227 @@ def __call__(
197204
return dpnp_array._create_from_usm_ndarray(res_usm)
198205

199206

207+
class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
    """
    Class that implements unary element-wise functions with two output arrays.

    Parameters
    ----------
    name : {str}
        Name of the unary function
    result_type_resolver_fn : {callable}
        Function that takes dtype of the input and returns a pair with
        the dtypes of the two results if the implementation function
        supports it, or a pair of `None` values otherwise.
    unary_dp_impl_fn : {callable}
        Data-parallel implementation function with signature
        `impl_fn(src: usm_ndarray, dst1: usm_ndarray, dst2: usm_ndarray,
        sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
        where the `src` is the argument array, `dst1` and `dst2` are the
        arrays to be populated with function values, effectively
        evaluating `dst1, dst2 = func(src)`.
        The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
        The first event corresponds to data-management host tasks,
        including lifetime management of argument Python objects to ensure
        that their associated USM allocation is not freed before offloaded
        computational tasks complete execution, while the second event
        corresponds to computational tasks associated with function evaluation.
    docs : {str}
        Documentation string for the unary function.

    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        unary_dp_impl_fn,
        docs,
    ):
        super().__init__(
            name,
            result_type_resolver_fn,
            unary_dp_impl_fn,
            docs,
        )
        self.__name__ = "DPNPUnaryTwoOutputsFunc"

    @property
    def nout(self):
        """Returns the number of arguments treated as outputs."""
        return 2

    def __call__(
        self,
        x,
        out1=None,
        out2=None,
        /,
        *,
        out=(None, None),
        where=True,
        order="K",
        dtype=None,
        subok=True,
        **kwargs,
    ):
        """
        Evaluate the function for input `x` and return both results.

        Parameters
        ----------
        x : {dpnp.ndarray, usm_ndarray}
            Input array.
        out1, out2 : {None, dpnp.ndarray, usm_ndarray}, optional
            Positional-only output arrays. Mutually exclusive with
            a non-empty `out` keyword.
        out : tuple of {None, dpnp.ndarray, usm_ndarray}, optional
            Tuple with exactly one entry per output.
        where, dtype, subok, **kwargs
            Accepted for signature compatibility only; any non-default
            value raises ``NotImplementedError``.
        order : {None, "C", "F", "A", "K"}, optional
            Memory layout of newly allocated arrays. ``None`` means "K".

        Returns
        -------
        out : list of dpnp.ndarray
            Two-element list with the result arrays.

        Raises
        ------
        NotImplementedError
            If an unsupported keyword is given a non-default value.
        ValueError
            If `order` is invalid, the input dtype is not supported,
            `out` has a wrong length, or an output array is read-only or
            has a shape different from `x`.
        TypeError
            If `out` is not a tuple, or outputs are passed both
            positionally and through `out`.
        ExecutionPlacementError
            If an output array is allocated on an incompatible queue.

        """

        if kwargs:
            raise NotImplementedError(
                f"Requested function={self.name_} with kwargs={kwargs} "
                "isn't currently supported."
            )
        elif where is not True:
            raise NotImplementedError(
                f"Requested function={self.name_} with where={where} "
                "isn't currently supported."
            )
        elif dtype is not None:
            raise NotImplementedError(
                f"Requested function={self.name_} with dtype={dtype} "
                "isn't currently supported."
            )
        elif subok is not True:
            raise NotImplementedError(
                f"Requested function={self.name_} with subok={subok} "
                "isn't currently supported."
            )

        x = dpnp.get_usm_ndarray(x)
        exec_q = x.sycl_queue

        if order is None:
            order = "K"
        elif order in "afkcAFKC" and len(order) == 1:
            # `in` on a string is a substring test, so the length check is
            # required to reject "" and multi-character values like "fk".
            order = order.upper()
            if order == "A":
                order = "F" if x.flags.f_contiguous else "C"
        else:
            raise ValueError(
                "order must be one of 'C', 'F', 'A', or 'K' " f"(got '{order}')"
            )

        buf_dt, res1_dt, res2_dt = find_buf_dtype_3out(
            x.dtype,
            self.result_type_resolver_fn_,
            x.sycl_device,
        )
        if res1_dt is None or res2_dt is None:
            raise ValueError(
                f"function '{self.name_}' does not support input type "
                f"({x.dtype}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        if not isinstance(out, tuple):
            raise TypeError("'out' must be a tuple of arrays")

        if len(out) != self.nout:
            raise ValueError(
                "'out' tuple must have exactly one entry per ufunc output"
            )

        if not (out1 is None and out2 is None):
            if all(res is None for res in out):
                out = (out1, out2)
            else:
                raise TypeError(
                    "cannot specify 'out' as both a positional and keyword argument"
                )

        # `orig_out` keeps what the caller passed; entries of `out` may be
        # swapped for temporary buffers below.
        orig_out, out = list(out), list(out)
        res_dts = [res1_dt, res2_dt]

        for i in range(self.nout):
            if out[i] is None:
                continue

            res = dpnp.get_usm_ndarray(out[i])
            if not res.flags.writable:
                raise ValueError("provided output array is read-only")

            if res.shape != x.shape:
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {res.shape}"
                )

            if dpu.get_execution_queue((exec_q, res.sycl_queue)) is None:
                raise dpnp.exceptions.ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

            res_dt = res_dts[i]
            if res_dt != res.dtype:
                # Allocate a temporary buffer with the required dtype
                out[i] = dpt.empty_like(res, dtype=res_dt)
            elif (
                buf_dt is None
                and dti._array_overlap(x, res)
                and not dti._same_logical_tensors(x, res)
            ):
                # Allocate a temporary buffer to avoid memory overlapping.
                # Note if `buf_dt` is not None, a temporary copy of `x` will be
                # created, so the array overlap check isn't needed.
                out[i] = dpt.empty_like(res)

        _manager = dpu.SequentialOrderManager[exec_q]

        # Cast input array to the supported type if needed
        if buf_dt is not None:
            if order == "K":
                buf = dtc._empty_like_orderK(x, buf_dt)
            else:
                buf = dpt.empty_like(x, dtype=buf_dt, order=order)

            ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray(
                src=x,
                dst=buf,
                sycl_queue=exec_q,
                depends=_manager.submitted_events,
            )
            # The copy events are registered with the manager, so they are
            # part of `submitted_events` queried before the kernel launch.
            _manager.add_event_pair(ht_copy_ev, copy_ev)

            x = buf

        # Allocate a buffer for the output arrays if needed
        for i in range(self.nout):
            if out[i] is None:
                res_dt = res_dts[i]
                if order == "K":
                    out[i] = dtc._empty_like_orderK(x, res_dt)
                else:
                    out[i] = dpt.empty_like(x, dtype=res_dt, order=order)

        # Call the unary function with input and output arrays
        dep_evs = _manager.submitted_events
        ht_unary_ev, unary_ev = self.get_implementation_function()(
            x,
            dpnp.get_usm_ndarray(out[0]),
            dpnp.get_usm_ndarray(out[1]),
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_unary_ev, unary_ev)

        for i in range(self.nout):
            orig_res, res = orig_out[i], out[i]
            if not (orig_res is None or orig_res is res):
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray(
                    src=res,
                    dst=dpnp.get_usm_ndarray(orig_res),
                    sycl_queue=exec_q,
                    depends=[unary_ev],
                )
                _manager.add_event_pair(ht_copy_ev, copy_ev)
                res = out[i] = orig_res

            if not isinstance(res, dpnp_array):
                # Always return dpnp.ndarray
                out[i] = dpnp_array._create_from_usm_ndarray(res)
        return out
426+
427+
200428
class DPNPBinaryFunc(BinaryElementwiseFunc):
201429
"""
202430
Class that implements binary element-wise functions.
@@ -262,6 +490,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
262490
sycl_dev - The :class:`dpctl.SyclDevice` where the function
263491
evaluation is carried out.
264492
One of `o1_dtype` and `o2_dtype` must be a ``dtype`` instance.
493+
265494
"""
266495

267496
def __init__(

dpnp/dpnp_utils/dpnp_utils_common.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,35 @@
2929

3030
from collections.abc import Iterable
3131

32-
from dpctl.tensor._type_utils import _can_cast
32+
import dpctl.tensor._type_utils as dtu
3333

3434
import dpnp
3535
from dpnp.dpnp_utils import map_dtype_to_device
3636

37-
__all__ = ["result_type_for_device", "to_supported_dtypes"]
37+
__all__ = [
38+
"find_buf_dtype_3out",
39+
"result_type_for_device",
40+
"to_supported_dtypes",
41+
]
42+
43+
44+
def find_buf_dtype_3out(arg_dtype, query_fn, sycl_dev):
    """
    Two-output counterpart of `_find_buf_dtype` from
    `dpctl.tensor._type_utils`.

    Parameters
    ----------
    arg_dtype : dtype
        Data type of the input argument.
    query_fn : callable
        Function called with a single dtype; its result is unpacked into
        two output dtypes. A `None` entry means the dtype is unsupported.
    sycl_dev : SyclDevice
        Device whose `has_aspect_fp16` and `has_aspect_fp64` flags limit
        the candidate dtypes.

    Returns
    -------
    out : tuple
        A 3-tuple ``(buf_dt, res1_dt, res2_dt)``. `buf_dt` is ``None`` when
        `arg_dtype` is supported directly, otherwise it is the first
        candidate dtype `arg_dtype` can be cast to for which `query_fn`
        reports both outputs. ``(None, None, None)`` means no candidate
        was found.

    """

    res1_dt, res2_dt = query_fn(arg_dtype)
    # Compare against None explicitly: "unsupported" is signalled by None,
    # and the truth value of a dtype object must not decide support.
    if res1_dt is not None and res2_dt is not None:
        return None, res1_dt, res2_dt

    _fp16 = sycl_dev.has_aspect_fp16
    _fp64 = sycl_dev.has_aspect_fp64
    all_dts = dtu._all_data_types(_fp16, _fp64)
    for buf_dt in all_dts:
        if dtu._can_cast(arg_dtype, buf_dt, _fp16, _fp64):
            res1_dt, res2_dt = query_fn(buf_dt)
            if res1_dt is not None and res2_dt is not None:
                return buf_dt, res1_dt, res2_dt

    return None, None, None
3861

3962

4063
def result_type_for_device(dtypes, device):
@@ -55,7 +78,7 @@ def to_supported_dtypes(dtypes, supported_types, device):
5578
has_fp16 = device.has_aspect_fp16
5679

5780
def is_castable(dtype, stype):
58-
return _can_cast(dtype, stype, has_fp16, has_fp64)
81+
return dtu._can_cast(dtype, stype, has_fp16, has_fp64)
5982

6083
if not isinstance(supported_types, Iterable):
6184
supported_types = (supported_types,) # pragma: no cover

0 commit comments

Comments
 (0)