Skip to content

Commit a7039a2

Browse files
committed
jnp.reshape: raise TypeError when specifying newshape
1 parent 2e0474a commit a7039a2

File tree

2 files changed

+6
-20
lines changed

2 files changed

+6
-20
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,20 +2143,11 @@ def reshape(
21432143
__tracebackhide__ = True
21442144
util.check_arraylike("reshape", a)
21452145

2146-
# TODO(micky774): deprecated 2024-5-9, remove after deprecation expires.
2146+
# TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
21472147
if not isinstance(newshape, DeprecatedArg):
2148-
if shape is not None:
2149-
raise ValueError(
2150-
"jnp.reshape received both `shape` and `newshape` arguments. Note that "
2151-
"using `newshape` is deprecated, please only use `shape` instead."
2152-
)
2153-
deprecations.warn(
2154-
"jax-numpy-reshape-newshape",
2155-
("The newshape argument of jax.numpy.reshape is deprecated. "
2156-
"Please use the shape argument instead."), stacklevel=2)
2157-
shape = newshape
2158-
del newshape
2159-
elif shape is None:
2148+
raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
2149+
" Use shape instead.")
2150+
if shape is None:
21602151
raise TypeError(
21612152
"jnp.shape requires passing a `shape` argument, but none was given."
21622153
)

tests/lax_numpy_test.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3428,13 +3428,8 @@ def testReshape(self, arg_shape, out_shape, dtype, order):
34283428
self._CompileAndCheck(jnp_fun, args_maker)
34293429

34303430
def testReshapeDeprecatedArgs(self):
3431-
msg = "The newshape argument of jax.numpy.reshape is deprecated."
3432-
def assert_warns_or_errors(msg=msg):
3433-
if deprecations.is_accelerated("jax-numpy-reshape-newshape"):
3434-
return self.assertRaisesRegex(ValueError, msg)
3435-
else:
3436-
return self.assertWarnsRegex(DeprecationWarning, msg)
3437-
with assert_warns_or_errors(msg):
3431+
msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36."
3432+
with self.assertRaisesRegex(TypeError, msg):
34383433
jnp.reshape(jnp.arange(4), newshape=(2, 2))
34393434

34403435
@jtu.sample_product(

0 commit comments

Comments
 (0)