You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Handle JAX _DimExpr as dtype. Default arangestep to None. (#21688)
The pattern used to get dtypes in ops that can handle scalars is `getattr(x, "dtype", type(x))`. The `type(x)` part is used for native Python types `int` and `float`. But if `x` is derived from a dynamic dimension with JAX, it can be a `_DimExpr`.
Changed `arange` to have a default `step` of `None` instead of `1`.
NumPy, Torch, JAX and Keras all document that the following versions of `arange` are supported:
- `arange(stop)`: generate values from 0 to stop, stepping by 1.
- `arange(start, stop)`: generate values from start to stop, stepping by 1.
- `arange(start, stop, step)`: generate values from start to stop, stepping by step.
Note that in the case of NumPy and Torch, this is achieved via overloads and not by detecting `None` values.
Regardless, the form `arange(stop, step=n)` is not officially supported. However, by having a default `step` value of `1`, we were always exercising this `arange(stop, step=1)` case when only `stop` was provided. This is causing issues with JAX is some specific contexts. By changing the default value of `step` to `None`, we're aligning with the JAX approach and we can detect when the value is not provided.
- Added support for `arange(stop, step=n)` for all backends with unit tests
- Fixed bug with Torch backend where `arange(stop, step=n)` would ignore the `step` value
These fix a regression in KerasHub introduced by #21672
0 commit comments