Skip to content

Commit e9543f7

Browse files
authored
Leverage dpnp.cumlogsumexp through dpctl.tensor implementation (#1816)
* Implement dpnp.cumprod through dpctl.tensor * Implement dpnp.nancumprod() through existing calls * Implement dpnp.cumlogsumexp through dpctl.tensor * Resolved pre-commit issues * Applied review comment * Fix test_out running on CPU * Generate docstings documentation for dpnp.cumlogsumexp
1 parent 6d181c4 commit e9543f7

File tree

6 files changed

+225
-0
lines changed

6 files changed

+225
-0
lines changed

doc/reference/math.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ Exponents and logarithms
9696
dpnp.logaddexp
9797
dpnp.logaddexp2
9898
dpnp.logsumexp
99+
dpnp.cumlogsumexp
99100

100101

101102
Other special functions

dpnp/dpnp_iface_trigonometric.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
"cbrt",
7272
"cos",
7373
"cosh",
74+
"cumlogsumexp",
7475
"deg2rad",
7576
"degrees",
7677
"exp",
@@ -665,6 +666,94 @@ def _get_accumulation_res_dt(a, dtype, _out):
665666
)
666667

667668

669+
def cumlogsumexp(
670+
x, /, *, axis=None, dtype=None, include_initial=False, out=None
671+
):
672+
"""
673+
Calculates the cumulative logarithm of the sum of elements in the input
674+
array `x`.
675+
676+
Parameters
677+
----------
678+
x : {dpnp.ndarray, usm_ndarray}
679+
Input array, expected to have a real-valued data type.
680+
axis : {None, int}, optional
681+
Axis or axes along which values must be computed. If a tuple of unique
682+
integers, values are computed over multiple axes. If ``None``, the
683+
result is computed over the entire array.
684+
Default: ``None``.
685+
dtype : {None, dtype}, optional
686+
Data type of the returned array. If ``None``, the default data type is
687+
inferred from the "kind" of the input array data type.
688+
689+
- If `x` has a real-valued floating-point data type, the returned array
690+
will have the same data type as `x`.
691+
- If `x` has a boolean or integral data type, the returned array will
692+
have the default floating point data type for the device where input
693+
array `x` is allocated.
694+
- If `x` has a complex-valued floating-point data type, an error is
695+
raised.
696+
697+
If the data type (either specified or resolved) differs from the data
698+
type of `x`, the input array elements are cast to the specified data
699+
type before computing the result.
700+
Default: ``None``.
701+
include_initial : {None, bool}, optional
702+
A boolean indicating whether to include the initial value (i.e., the
703+
additive identity, zero) as the first value along the provided axis in
704+
the output.
705+
Default: ``False``.
706+
out : {None, dpnp.ndarray, usm_ndarray}, optional
707+
The array into which the result is written. The data type of `out` must
708+
match the expected shape and the expected data type of the result or
709+
(if provided) `dtype`. If ``None`` then a new array is returned.
710+
Default: ``None``.
711+
712+
Returns
713+
-------
714+
out : dpnp.ndarray
715+
An array containing the results. If the result was computed over the
716+
entire array, a zero-dimensional array is returned. The returned array
717+
has the data type as described in the `dtype` parameter description
718+
above.
719+
720+
Note
721+
----
722+
This function is equivalent of `numpy.logaddexp.accumulate`.
723+
724+
See Also
725+
--------
726+
:obj:`dpnp.logsumexp` : Logarithm of the sum of elements of the inputs,
727+
element-wise.
728+
729+
Examples
730+
--------
731+
>>> import dpnp as np
732+
>>> a = np.ones(10)
733+
>>> np.cumlogsumexp(a)
734+
array([1. , 1.69314718, 2.09861229, 2.38629436, 2.60943791,
735+
2.79175947, 2.94591015, 3.07944154, 3.19722458, 3.30258509])
736+
737+
"""
738+
739+
dpnp.check_supported_arrays_type(x)
740+
if x.ndim > 1 and axis is None:
741+
usm_x = dpnp.ravel(x).get_array()
742+
else:
743+
usm_x = dpnp.get_usm_ndarray(x)
744+
745+
return dpnp_wrap_reduction_call(
746+
x,
747+
out,
748+
dpt.cumulative_logsumexp,
749+
_get_accumulation_res_dt,
750+
usm_x,
751+
axis=axis,
752+
dtype=dtype,
753+
include_initial=include_initial,
754+
)
755+
756+
668757
def deg2rad(x1):
669758
"""
670759
Convert angles from degrees to radians.

tests/test_mathematical.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,114 @@ def test_not_implemented_kwargs(self, kwargs):
161161
dpnp.clip(a, 1, 5, **kwargs)
162162

163163

164+
class TestCumLogSumExp:
165+
def _assert_arrays(self, res, exp, axis, include_initial):
166+
if include_initial:
167+
if axis != None:
168+
res_initial = dpnp.take(res, dpnp.array([0]), axis=axis)
169+
res_no_initial = dpnp.take(
170+
res, dpnp.array(range(1, res.shape[axis])), axis=axis
171+
)
172+
else:
173+
res_initial = res[0]
174+
res_no_initial = res[1:]
175+
assert_dtype_allclose(res_no_initial, exp)
176+
assert (res_initial == -dpnp.inf).all()
177+
else:
178+
assert_dtype_allclose(res, exp)
179+
180+
def _get_exp_array(self, a, axis, dtype):
181+
np_a = dpnp.asnumpy(a)
182+
if axis != None:
183+
return numpy.logaddexp.accumulate(np_a, axis=axis, dtype=dtype)
184+
return numpy.logaddexp.accumulate(np_a.ravel(), dtype=dtype)
185+
186+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
187+
@pytest.mark.parametrize("axis", [None, 2, -1])
188+
@pytest.mark.parametrize("include_initial", [True, False])
189+
def test_basic(self, dtype, axis, include_initial):
190+
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
191+
res = dpnp.cumlogsumexp(a, axis=axis, include_initial=include_initial)
192+
193+
exp_dt = None
194+
if dtype == dpnp.bool:
195+
exp_dt = dpnp.default_float_type(a.device)
196+
197+
exp = self._get_exp_array(a, axis, exp_dt)
198+
self._assert_arrays(res, exp, axis, include_initial)
199+
200+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
201+
@pytest.mark.parametrize("axis", [None, 2, -1])
202+
@pytest.mark.parametrize("include_initial", [True, False])
203+
def test_out(self, dtype, axis, include_initial):
204+
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
205+
206+
if dpnp.issubdtype(a, dpnp.float32):
207+
exp_dt = dpnp.float32
208+
else:
209+
exp_dt = dpnp.default_float_type(a.device)
210+
211+
if axis != None:
212+
if include_initial:
213+
norm_axis = numpy.core.numeric.normalize_axis_index(
214+
axis, a.ndim, "axis"
215+
)
216+
out_sh = (
217+
a.shape[:norm_axis]
218+
+ (a.shape[norm_axis] + 1,)
219+
+ a.shape[norm_axis + 1 :]
220+
)
221+
else:
222+
out_sh = a.shape
223+
else:
224+
out_sh = (a.size + int(include_initial),)
225+
out = dpnp.empty_like(a, shape=out_sh, dtype=exp_dt)
226+
res = dpnp.cumlogsumexp(
227+
a, axis=axis, include_initial=include_initial, out=out
228+
)
229+
230+
exp = self._get_exp_array(a, axis, exp_dt)
231+
232+
assert res is out
233+
self._assert_arrays(res, exp, axis, include_initial)
234+
235+
def test_axis_tuple(self):
236+
a = dpnp.ones((3, 4))
237+
assert_raises(TypeError, dpnp.cumlogsumexp, a, axis=(0, 1))
238+
239+
@pytest.mark.parametrize(
240+
"in_dtype", get_all_dtypes(no_bool=True, no_complex=True)
241+
)
242+
@pytest.mark.parametrize("out_dtype", get_all_dtypes(no_bool=True))
243+
def test_dtype(self, in_dtype, out_dtype):
244+
a = dpnp.ones(100, dtype=in_dtype)
245+
res = dpnp.cumlogsumexp(a, dtype=out_dtype)
246+
exp = numpy.logaddexp.accumulate(dpnp.asnumpy(a))
247+
exp = exp.astype(out_dtype)
248+
249+
assert_allclose(res, exp, rtol=1e-06)
250+
251+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
252+
@pytest.mark.parametrize(
253+
"arr_dt", get_all_dtypes(no_none=True, no_complex=True)
254+
)
255+
@pytest.mark.parametrize(
256+
"out_dt", get_all_dtypes(no_none=True, no_complex=True)
257+
)
258+
@pytest.mark.parametrize("dtype", get_all_dtypes())
259+
def test_out_dtype(self, arr_dt, out_dt, dtype):
260+
a = numpy.arange(10, 20).reshape((2, 5)).astype(dtype=arr_dt)
261+
out = numpy.zeros_like(a, dtype=out_dt)
262+
263+
ia = dpnp.array(a)
264+
iout = dpnp.array(out)
265+
266+
result = dpnp.cumlogsumexp(ia, out=iout, dtype=dtype, axis=1)
267+
exp = numpy.logaddexp.accumulate(a, out=out, axis=1)
268+
assert_allclose(result, exp.astype(dtype), rtol=1e-06)
269+
assert result is iout
270+
271+
164272
class TestCumProd:
165273
@pytest.mark.parametrize(
166274
"arr, axis",

tests/test_strides.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,16 @@ def test_logsumexp(dtype):
125125
assert_allclose(result, expected)
126126

127127

128+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
129+
def test_cumlogsumexp(dtype):
130+
a = numpy.arange(10, dtype=dtype)[::2]
131+
dpa = dpnp.arange(10, dtype=dtype)[::2]
132+
133+
result = dpnp.cumlogsumexp(dpa)
134+
expected = numpy.logaddexp.accumulate(a)
135+
assert_allclose(result, expected)
136+
137+
128138
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
129139
def test_reduce_hypot(dtype):
130140
a = numpy.arange(10, dtype=dtype)[::2]

tests/test_sycl_queue.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,22 @@ def test_logsumexp(device):
556556
assert_sycl_queue_equal(result_queue, expected_queue)
557557

558558

559+
@pytest.mark.parametrize(
560+
"device",
561+
valid_devices,
562+
ids=[device.filter_string for device in valid_devices],
563+
)
564+
def test_cumlogsumexp(device):
565+
x = dpnp.arange(10, device=device)
566+
result = dpnp.cumlogsumexp(x)
567+
expected = numpy.logaddexp.accumulate(x.asnumpy())
568+
assert_dtype_allclose(result, expected)
569+
570+
expected_queue = x.get_array().sycl_queue
571+
result_queue = result.get_array().sycl_queue
572+
assert_sycl_queue_equal(result_queue, expected_queue)
573+
574+
559575
@pytest.mark.parametrize(
560576
"device",
561577
valid_devices,

tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ def test_norm(usm_type, ord, axis):
522522
),
523523
pytest.param("cosh", [-5.0, -3.5, 0.0, 3.5, 5.0]),
524524
pytest.param("count_nonzero", [0, 1, 7, 0]),
525+
pytest.param("cumlogsumexp", [1.0, 2.0, 4.0, 7.0]),
525526
pytest.param("cumprod", [[1, 2, 3], [4, 5, 6]]),
526527
pytest.param("cumsum", [[1, 2, 3], [4, 5, 6]]),
527528
pytest.param("diagonal", [[[1, 2], [3, 4]]]),

0 commit comments

Comments
 (0)