We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f74158b commit db821faCopy full SHA for db821fa
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -79,7 +79,11 @@ def lower_wrapper(*args):
79
nonlocal jax_lower_static_kwargs
80
81
jaxfn_args = []
82
- jaxfn_kwargs = jax_lower_static_kwargs.copy()
+ jaxfn_kwargs = {
83
+ k: jax.numpy.array(v) if isinstance(v, (int, float)) else v
84
+ for k, v in jax_lower_static_kwargs.items()
85
+ }
86
+
87
for name, arg in zip(jax_lower_argnames, args):
88
if name is None:
89
jaxfn_args.append(arg)
0 commit comments