Skip to content

Commit db821fa

Browse files
junjiang-labcopybara-github
authored andcommitted
Fix JAX bridge incompatibility with NumPy < 2.0
PiperOrigin-RevId: 823564334
1 parent f74158b commit db821fa

File tree

1 file changed

+5
-1
lines changed
  • ai_edge_torch/odml_torch/jax_bridge

1 file changed

+5
-1
lines changed

ai_edge_torch/odml_torch/jax_bridge/_wrap.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ def lower_wrapper(*args):
7979
nonlocal jax_lower_static_kwargs
8080

8181
jaxfn_args = []
82-
jaxfn_kwargs = jax_lower_static_kwargs.copy()
82+
jaxfn_kwargs = {
83+
k: jax.numpy.array(v) if isinstance(v, (int, float)) else v
84+
for k, v in jax_lower_static_kwargs.items()
85+
}
86+
8387
for name, arg in zip(jax_lower_argnames, args):
8488
if name is None:
8589
jaxfn_args.append(arg)

0 commit comments

Comments
 (0)