Skip to content

Commit 26d7166

Browse files
authored
Handle JAX _DimExpr as dtype. Default arange step to None. (#21688)
The pattern used to get dtypes in ops that can handle scalars is `getattr(x, "dtype", type(x))`. The `type(x)` part is used for the native Python types `int` and `float`. But if `x` is derived from a dynamic dimension with JAX, it can be a `_DimExpr`, in which case `type(x)` is not a usable dtype; the JAX backend now treats such symbolic dimensions as `int`.

Changed `arange` to have a default `step` of `None` instead of `1`. NumPy, Torch, JAX and Keras all document that the following versions of `arange` are supported:
- `arange(stop)`: generate values from 0 to stop, stepping by 1.
- `arange(start, stop)`: generate values from start to stop, stepping by 1.
- `arange(start, stop, step)`: generate values from start to stop, stepping by step.

Note that in the case of NumPy and Torch, this is achieved via overloads and not by detecting `None` values. Regardless, the form `arange(stop, step=n)` is not officially supported. However, by having a default `step` value of `1`, we were always exercising this `arange(stop, step=1)` case when only `stop` was provided. This is causing issues with JAX in some specific contexts. By changing the default value of `step` to `None`, we're aligning with the JAX approach and we can detect when the value is not provided.

- Added support for `arange(stop, step=n)` for all backends, with unit tests.
- Fixed a bug with the Torch backend where `arange(stop, step=n)` would ignore the `step` value.

These changes fix a regression in KerasHub introduced by #21672.
1 parent cdc7e29 commit 26d7166

File tree

8 files changed

+50
-33
lines changed

8 files changed

+50
-33
lines changed

keras/src/backend/common/variables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,10 @@ def standardize_shape(shape):
599599

600600
if config.backend() == "jax":
601601
# Replace `_DimExpr` (dimension expression) with None
602+
from jax import export as jax_export
603+
602604
shape = tuple(
603-
[None if "_DimExpr" in str(type(d)) else d for d in shape]
605+
None if jax_export.is_symbolic_dim(d) else d for d in shape
604606
)
605607

606608
if config.backend() == "torch":

keras/src/backend/jax/numpy.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import jax.experimental.sparse as jax_sparse
55
import jax.numpy as jnp
6+
from jax import export as jax_export
67

78
from keras.src.backend import config
89
from keras.src.backend.common import dtypes
@@ -306,14 +307,20 @@ def append(x1, x2, axis=None):
306307
return jnp.append(x1, x2, axis=axis)
307308

308309

309-
def arange(start, stop=None, step=1, dtype=None):
310+
def arange(start, stop=None, step=None, dtype=None):
311+
def get_dtype(x):
312+
if hasattr(x, "dtype"):
313+
return x.dtype
314+
if jax_export.is_symbolic_dim(x):
315+
return int
316+
return type(x)
317+
310318
if dtype is None:
311-
dtypes_to_resolve = [
312-
getattr(start, "dtype", type(start)),
313-
getattr(step, "dtype", type(step)),
314-
]
319+
dtypes_to_resolve = [get_dtype(start)]
315320
if stop is not None:
316-
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
321+
dtypes_to_resolve.append(get_dtype(stop))
322+
if step is not None:
323+
dtypes_to_resolve.append(get_dtype(step))
317324
dtype = dtypes.result_type(*dtypes_to_resolve)
318325
dtype = standardize_dtype(dtype)
319326
return jnp.arange(start, stop, step=step, dtype=dtype)

keras/src/backend/numpy/numpy.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,18 @@ def append(x1, x2, axis=None):
173173
return np.append(x1, x2, axis=axis)
174174

175175

176-
def arange(start, stop=None, step=1, dtype=None):
176+
def arange(start, stop=None, step=None, dtype=None):
177177
if dtype is None:
178-
dtypes_to_resolve = [
179-
getattr(start, "dtype", type(start)),
180-
getattr(step, "dtype", type(step)),
181-
]
178+
dtypes_to_resolve = [getattr(start, "dtype", type(start))]
182179
if stop is not None:
183180
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
181+
if step is not None:
182+
dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
184183
dtype = dtypes.result_type(*dtypes_to_resolve)
184+
if stop is None:
185+
start, stop = 0, start
186+
if step is None:
187+
step = 1
185188
return np.arange(start, stop, step=step, dtype=dtype)
186189

187190

keras/src/backend/openvino/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def append(x1, x2, axis=None):
210210
return OpenVINOKerasTensor(ov_opset.concat([x1, x2], axis).output(0))
211211

212212

213-
def arange(start, stop=None, step=1, dtype=None):
213+
def arange(start, stop=None, step=None, dtype=None):
214214
if stop is None:
215215
start, stop = get_ov_output(0), get_ov_output(start)
216216
else:

keras/src/backend/tensorflow/numpy.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -813,16 +813,17 @@ def append(x1, x2, axis=None):
813813
return tf.concat([x1, x2], axis=axis)
814814

815815

816-
def arange(start, stop=None, step=1, dtype=None):
816+
def arange(start, stop=None, step=None, dtype=None):
817817
if dtype is None:
818-
dtypes_to_resolve = [
819-
getattr(start, "dtype", type(start)),
820-
getattr(step, "dtype", type(step)),
821-
]
818+
dtypes_to_resolve = [getattr(start, "dtype", type(start))]
822819
if stop is not None:
823820
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
821+
if step is not None:
822+
dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
824823
dtype = dtypes.result_type(*dtypes_to_resolve)
825824
dtype = standardize_dtype(dtype)
825+
if step is None:
826+
step = 1
826827
try:
827828
out = tf.range(start, stop, delta=step, dtype=dtype)
828829
except tf.errors.NotFoundError:

keras/src/backend/torch/numpy.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,18 +313,19 @@ def append(x1, x2, axis=None):
313313
return torch.cat((x1, x2), dim=axis)
314314

315315

316-
def arange(start, stop=None, step=1, dtype=None):
316+
def arange(start, stop=None, step=None, dtype=None):
317317
if dtype is None:
318-
dtypes_to_resolve = [
319-
getattr(start, "dtype", type(start)),
320-
getattr(step, "dtype", type(step)),
321-
]
318+
dtypes_to_resolve = [getattr(start, "dtype", type(start))]
322319
if stop is not None:
323320
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
321+
if step is not None:
322+
dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
324323
dtype = dtypes.result_type(*dtypes_to_resolve)
325324
dtype = to_torch_dtype(dtype)
326325
if stop is None:
327-
return torch.arange(end=start, dtype=dtype, device=get_device())
326+
start, stop = 0, start
327+
if step is None:
328+
step = 1
328329
return torch.arange(
329330
start, stop, step=step, dtype=dtype, device=get_device()
330331
)

keras/src/ops/numpy.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -595,27 +595,28 @@ def __init__(self, dtype=None, *, name=None):
595595
super().__init__(name=name)
596596
self.dtype = None if dtype is None else backend.standardize_dtype(dtype)
597597

598-
def call(self, start, stop=None, step=1):
598+
def call(self, start, stop=None, step=None):
599599
return backend.numpy.arange(start, stop, step=step, dtype=self.dtype)
600600

601-
def compute_output_spec(self, start, stop=None, step=1):
601+
def compute_output_spec(self, start, stop=None, step=None):
602602
if stop is None:
603603
start, stop = 0, start
604+
if step is None:
605+
step = 1
604606
output_shape = [int(np.ceil((stop - start) / step))]
605607
dtype = self.dtype
606608
if dtype is None:
607-
dtypes_to_resolve = [
608-
getattr(start, "dtype", type(start)),
609-
getattr(step, "dtype", type(step)),
610-
]
609+
dtypes_to_resolve = [getattr(start, "dtype", type(start))]
611610
if stop is not None:
612611
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
612+
if step is not None:
613+
dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
613614
dtype = dtypes.result_type(*dtypes_to_resolve)
614615
return KerasTensor(output_shape, dtype=dtype)
615616

616617

617618
@keras_export(["keras.ops.arange", "keras.ops.numpy.arange"])
618-
def arange(start, stop=None, step=1, dtype=None):
619+
def arange(start, stop=None, step=None, dtype=None):
619620
"""Return evenly spaced values within a given interval.
620621
621622
`arange` can be called with a varying number of positional arguments:

keras/src/ops/numpy_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6357,8 +6357,10 @@ def test_argsort(self, dtype):
63576357
)
63586358

63596359
@parameterized.parameters(
6360-
(10, None, 1, None),
6361-
(0, 10, 1, None),
6360+
(10, None, None, None), # stop
6361+
(2, 10, None, None), # start, stop
6362+
(10, None, 2, None), # stop, step
6363+
(0, 10, 2, None), # start, stop, step
63626364
(0, 10, 0.5, None),
63636365
(10.0, None, 1, None),
63646366
(0, 10.0, 1, None),

0 commit comments

Comments
 (0)