diff --git a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py index bdf5c63d..e006d800 100644 --- a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py +++ b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py @@ -79,7 +79,13 @@ def lower_wrapper(*args): nonlocal jax_lower_static_kwargs jaxfn_args = [] - jaxfn_kwargs = jax_lower_static_kwargs.copy() + # TODO(junjiang): revert to jax_lower_static_kwargs.copy() once NumPy 2.0 is + # the minimum supported version. + jaxfn_kwargs = { + k: jax.numpy.array(v) if isinstance(v, float) else v + for k, v in jax_lower_static_kwargs.items() + } + for name, arg in zip(jax_lower_argnames, args): if name is None: jaxfn_args.append(arg)