Skip to content

Commit 3d83143

Browse files
authored
Merge pull request numpy#19356 from mhvk/functionbase-vectorize-refactor
API: Ensure np.vectorize outputs can be subclasses.
2 parents b63e256 + 3c892cd commit 3d83143

File tree

3 files changed

+53
-21
lines changed

3 files changed

+53
-21
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
`numpy.vectorize` functions now produce the same output class as the base function
2+
----------------------------------------------------------------------------------
3+
When a function that respects `numpy.ndarray` subclasses is vectorized using
4+
`numpy.vectorize`, the vectorized function will now be subclass-safe
5+
also for cases that a signature is given (i.e., when creating a ``gufunc``):
6+
the output class will be the same as that returned by the first call to
7+
the underlying function.

numpy/lib/function_base.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,7 +1522,7 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
15221522
p : array_like
15231523
Input array.
15241524
discont : float, optional
1525-
Maximum discontinuity between values, default is ``period/2``.
1525+
Maximum discontinuity between values, default is ``period/2``.
15261526
Values below ``period/2`` are treated as if they were ``period/2``.
15271527
To have an effect different from the default, `discont` should be
15281528
larger than ``period/2``.
@@ -1531,7 +1531,7 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
15311531
period: float, optional
15321532
Size of the range over which the input wraps. By default, it is
15331533
``2 pi``.
1534-
1534+
15351535
.. versionadded:: 1.21.0
15361536
15371537
Returns
@@ -1545,8 +1545,8 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
15451545
15461546
Notes
15471547
-----
1548-
If the discontinuity in `p` is smaller than ``period/2``,
1549-
but larger than `discont`, no unwrapping is done because taking
1548+
If the discontinuity in `p` is smaller than ``period/2``,
1549+
but larger than `discont`, no unwrapping is done because taking
15501550
the complement would only make the discontinuity larger.
15511551
15521552
Examples
@@ -1579,7 +1579,7 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
15791579
slice1 = tuple(slice1)
15801580
dtype = np.result_type(dd, period)
15811581
if _nx.issubdtype(dtype, _nx.integer):
1582-
interval_high, rem = divmod(period, 2)
1582+
interval_high, rem = divmod(period, 2)
15831583
boundary_ambiguous = rem == 0
15841584
else:
15851585
interval_high = period / 2
@@ -1943,11 +1943,19 @@ def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
19431943
for core_dims in list_of_core_dims]
19441944

19451945

1946-
def _create_arrays(broadcast_shape, dim_sizes, list_of_core_dims, dtypes):
1946+
def _create_arrays(broadcast_shape, dim_sizes, list_of_core_dims, dtypes,
1947+
results=None):
19471948
"""Helper for creating output arrays in vectorize."""
19481949
shapes = _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims)
1949-
arrays = tuple(np.empty(shape, dtype=dtype)
1950-
for shape, dtype in zip(shapes, dtypes))
1950+
if dtypes is None:
1951+
dtypes = [None] * len(shapes)
1952+
if results is None:
1953+
arrays = tuple(np.empty(shape=shape, dtype=dtype)
1954+
for shape, dtype in zip(shapes, dtypes))
1955+
else:
1956+
arrays = tuple(np.empty_like(result, shape=shape, dtype=dtype)
1957+
for result, shape, dtype
1958+
in zip(results, shapes, dtypes))
19511959
return arrays
19521960

19531961

@@ -2293,11 +2301,8 @@ def _vectorize_call_with_signature(self, func, args):
22932301
for result, core_dims in zip(results, output_core_dims):
22942302
_update_dim_sizes(dim_sizes, result, core_dims)
22952303

2296-
if otypes is None:
2297-
otypes = [asarray(result).dtype for result in results]
2298-
22992304
outputs = _create_arrays(broadcast_shape, dim_sizes,
2300-
output_core_dims, otypes)
2305+
output_core_dims, otypes, results)
23012306

23022307
for output, result in zip(outputs, results):
23032308
output[index] = result
@@ -4136,13 +4141,13 @@ def trapz(y, x=None, dx=1.0, axis=-1):
41364141
41374142
If `x` is provided, the integration happens in sequence along its
41384143
elements - they are not sorted.
4139-
4144+
41404145
Integrate `y` (`x`) along each 1d slice on the given axis, compute
41414146
:math:`\int y(x) dx`.
41424147
When `x` is specified, this integrates along the parametric curve,
41434148
computing :math:`\int_t y(t) dt =
41444149
\int_t y(t) \left.\frac{dx}{dt}\right|_{x=x(t)} dt`.
4145-
4150+
41464151
Parameters
41474152
----------
41484153
y : array_like
@@ -4163,7 +4168,7 @@ def trapz(y, x=None, dx=1.0, axis=-1):
41634168
a single axis by the trapezoidal rule. If 'y' is a 1-dimensional array,
41644169
then the result is a float. If 'n' is greater than 1, then the result
41654170
is an 'n-1' dimensional array.
4166-
4171+
41674172
See Also
41684173
--------
41694174
sum, cumsum
@@ -4192,16 +4197,16 @@ def trapz(y, x=None, dx=1.0, axis=-1):
41924197
8.0
41934198
>>> np.trapz([1,2,3], dx=2)
41944199
8.0
4195-
4200+
41964201
Using a decreasing `x` corresponds to integrating in reverse:
4197-
4198-
>>> np.trapz([1,2,3], x=[8,6,4])
4202+
4203+
>>> np.trapz([1,2,3], x=[8,6,4])
41994204
-8.0
4200-
4205+
42014206
More generally `x` is used to integrate along a parametric curve.
42024207
This finds the area of a circle, noting we repeat the sample which closes
42034208
the curve:
4204-
4209+
42054210
>>> theta = np.linspace(0, 2 * np.pi, num=1000, endpoint=True)
42064211
>>> np.trapz(np.cos(theta), x=np.sin(theta))
42074212
3.141571941375841

numpy/lib/tests/test_function_base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,26 @@ def test_size_zero_output(self):
16651665
with assert_raises_regex(ValueError, 'new output dimensions'):
16661666
f(x)
16671667

1668+
def test_subclasses(self):
1669+
class subclass(np.ndarray):
1670+
pass
1671+
1672+
m = np.array([[1., 0., 0.],
1673+
[0., 0., 1.],
1674+
[0., 1., 0.]]).view(subclass)
1675+
v = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).view(subclass)
1676+
# generalized (gufunc)
1677+
matvec = np.vectorize(np.matmul, signature='(m,m),(m)->(m)')
1678+
r = matvec(m, v)
1679+
assert_equal(type(r), subclass)
1680+
assert_equal(r, [[1., 3., 2.], [4., 6., 5.], [7., 9., 8.]])
1681+
1682+
# element-wise (ufunc)
1683+
mult = np.vectorize(lambda x, y: x*y)
1684+
r = mult(m, v)
1685+
assert_equal(type(r), subclass)
1686+
assert_equal(r, m * v)
1687+
16681688

16691689
class TestLeaks:
16701690
class A:
@@ -1798,7 +1818,7 @@ def test_simple(self):
17981818
assert_array_equal(unwrap([1, 1 + 2 * np.pi]), [1, 1])
17991819
# check that unwrap maintains continuity
18001820
assert_(np.all(diff(unwrap(rand(10) * 100)) < np.pi))
1801-
1821+
18021822
def test_period(self):
18031823
# check that unwrap removes jumps greater that 255
18041824
assert_array_equal(unwrap([1, 1 + 256], period=255), [1, 2])

0 commit comments

Comments
 (0)