Skip to content

Commit 783285a

Browse files
yashk2810Google-ML-Automation
authored andcommitted
FIx jax2tf breakge of iota
PiperOrigin-RevId: 688146581
1 parent f833891 commit 783285a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/experimental/jax2tf/jax2tf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1758,7 +1758,7 @@ def _conj(x, **kwargs):
17581758
tf_impl[lax.mul_p] = tf.math.multiply
17591759

17601760

1761-
def _iota(*, dtype, shape, dimension, sharding):
1761+
def _iota(*, dtype, shape, dimension, sharding=None):
17621762
dtype = _to_tf_dtype(dtype)
17631763
# Some dtypes are unsupported, like uint32, so we just fall back to int32.
17641764
# TODO(mattjj, necula): improve tf.range dtype handling

0 commit comments

Comments
 (0)