Skip to content

Commit 4ffe88d

Browse files
BUG: integrate.trapezoid: fix broadcasting issue (scipy#21912)
* BUG: fix gh21908 * MAINT: broadcast n-D x to y * STY * TST: fix test * Apply suggestions from code review Co-authored-by: Lucas Colley <[email protected]> * TST: np --> xp --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 1b7ab3b commit 4ffe88d

File tree

2 files changed

+37
-13
lines changed

2 files changed

+37
-13
lines changed

scipy/integrate/_quadrature.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,16 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
136136
d = dx
137137
else:
138138
x = _asarray(x, xp=xp, subok=True)
139-
# reshape to correct shape
140-
shape = [1] * y.ndim
141-
shape[axis] = y.shape[axis]
142-
x = xp.reshape(x, shape)
143-
d = x[tuple(slice1)] - x[tuple(slice2)]
139+
if x.ndim == 1:
140+
d = x[1:] - x[:-1]
141+
# make d broadcastable to y
142+
slice3 = [None] * nd
143+
slice3[axis] = slice(None)
144+
d = d[tuple(slice3)]
145+
else:
146+
# if x is n-D it should be broadcastable to y
147+
x = xp.broadcast_to(x, y.shape)
148+
d = x[tuple(slice1)] - x[tuple(slice2)]
144149
try:
145150
ret = xp.sum(
146151
d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0,

scipy/integrate/tests/test_quadrature.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,14 +309,6 @@ def test_ndim(self, xp):
309309
r = trapezoid(q, x=z[None, None,:], axis=2)
310310
xp_assert_close(r, qz)
311311

312-
# n-d `x` but not the same as `y`
313-
r = trapezoid(q, x=xp.reshape(x[:, None, None], (3, 1)), axis=0)
314-
xp_assert_close(r, qx)
315-
r = trapezoid(q, x=xp.reshape(y[None,:, None], (8, 1)), axis=1)
316-
xp_assert_close(r, qy)
317-
r = trapezoid(q, x=xp.reshape(z[None, None,:], (13, 1)), axis=2)
318-
xp_assert_close(r, qz)
319-
320312
# 1-d `x`
321313
r = trapezoid(q, x=x, axis=0)
322314
xp_assert_close(r, qx)
@@ -325,6 +317,33 @@ def test_ndim(self, xp):
325317
r = trapezoid(q, x=z, axis=2)
326318
xp_assert_close(r, qz)
327319

320+
@skip_xp_backends('jax.numpy',
321+
reasons=["JAX arrays do not support item assignment"])
322+
@pytest.mark.usefixtures("skip_xp_backends")
323+
def test_gh21908(self, xp):
324+
# extended testing for n-dim arrays
325+
x = xp.reshape(xp.linspace(0, 29, 30), (3, 10))
326+
y = xp.reshape(xp.linspace(0, 29, 30), (3, 10))
327+
328+
out0 = xp.linspace(200, 380, 10)
329+
xp_assert_close(trapezoid(y, x=x, axis=0), out0)
330+
xp_assert_close(trapezoid(y, x=xp.asarray([0, 10., 20.]), axis=0), out0)
331+
# x needs to be broadcastable against y
332+
xp_assert_close(
333+
trapezoid(y, x=xp.asarray([0, 10., 20.])[:, None], axis=0),
334+
out0
335+
)
336+
with pytest.raises(Exception):
337+
# x is not broadcastable against y
338+
trapezoid(y, x=xp.asarray([0, 10., 20.])[None, :], axis=0)
339+
340+
out1 = xp.asarray([ 40.5, 130.5, 220.5])
341+
xp_assert_close(trapezoid(y, x=x, axis=1), out1)
342+
xp_assert_close(
343+
trapezoid(y, x=xp.linspace(0, 9, 10), axis=1),
344+
out1
345+
)
346+
328347
@skip_xp_invalid_arg
329348
def test_masked(self, xp):
330349
# Testing that masked arrays behave as if the function is 0 where

0 commit comments

Comments
 (0)