You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I found that jax.experimental.jax2tf.call_tf with JIT is not working with Cloud TPU and colab notebooks.
This small example, derived from the README, fails with a ValueError:
I confirmed that this sample runs smoothly with CPU and GPU runtimes.
I know jax2tf is "experimental", but it is quite convenient, so I want to use it with TPUs.
Here's the stack trace:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, code_gen_optional)
319 func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
--> 320 stage="hlo_serialized", device_name=tf_device_name)
321 except Exception as e:
32 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in compiler_ir_generator(stage, device_name)
1068 function_name=fn_name,
-> 1069 args=list(filtered_flat_args) + concrete_fn.captured_inputs)
1070 if stage in ("hlo_serialized", "optimized_hlo_serialized",
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/context.py in get_compiler_ir(self, device_name, function_name, args, stage)
1637 return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
-> 1638 stage, device_name, args)
1639
ValueError: No matching device found for '/device:TPU:0'
The above exception was the direct cause of the following exception:
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-16-974d2d918d16> in <module>()
2 jitted_f = jax.jit(cos_tf_sin_jax)
----> 3 output = jitted_f(x)
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
418 device=device, backend=backend, name=flat_fun.__name__,
--> 419 donated_invars=donated_invars, inline=inline)
420 out_pytree_def = out_tree()
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1631 def bind(self, fun, *args, **params):
-> 1632 return call_bind(self, fun, *args, **params)
1633
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1622 tracers = map(top_trace.full_raise, args)
-> 1623 outs = primitive.process(top_trace, fun, tracers, params)
1624 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
1634 def process(self, trace, fun, tracers, params):
-> 1635 return trace.process_call(self, fun, tracers, params)
1636
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
626 def process_call(self, primitive, f, tracers, params):
--> 627 return primitive.impl(f, *tracers, **params)
628 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
687 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 688 *unsafe_map(arg_spec, args))
689 try:
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
262 else:
--> 263 ans = call(fun, *args)
264 cache[key] = (ans, fun.stores)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
759 return lower_xla_callable(fun, device, backend, name, donated_invars,
--> 760 *arg_specs).compile().unsafe_call
761
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
771 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
--> 772 fun, abstract_args, pe.debug_info_final(fun, "jit"))
773 if any(isinstance(c, core.Tracer) for c in consts):
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, debug_info)
1541 with core.new_sublevel():
-> 1542 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1543 del fun, main
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1519 in_tracers = map(trace.new_arg, in_avals)
-> 1520 ans = fun.call_wrapped(*in_tracers)
1521 out_tracers = map(trace.full_raise, ans)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
<ipython-input-7-fb2d1d517a65> in cos_tf_sin_jax(x)
5 def cos_tf_sin_jax(x):
----> 6 return jax.numpy.sin(jax.experimental.jax2tf.call_tf(cos_tf)(x))
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/custom_derivatives.py in __call__(self, *args, **kwargs)
525 out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
--> 526 out_trees=out_trees)
527 fst, aux = lu.merge_linear_aux(out_tree, out_trees)
/usr/local/lib/python3.7/dist-packages/jax/_src/custom_derivatives.py in bind(self, fun, fwd, bwd, out_trees, *args)
608 outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
--> 609 out_trees=out_trees)
610 _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees)
1408 with core.new_sublevel():
-> 1409 fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
1410 closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1519 in_tracers = map(trace.new_arg, in_avals)
-> 1520 ans = fun.call_wrapped(*in_tracers)
1521 out_tracers = map(trace.full_raise, ans)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in make_call(*args_jax)
131 function_flat_tf=function_flat_tf,
--> 132 args_flat_sig_tf=args_flat_sig_tf)
133 return res_treedef.unflatten(res_jax_flat)
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
271 tracers = map(top_trace.full_raise, args)
--> 272 out = top_trace.process_primitive(self, tracers, params)
273 return map(full_lower, out) if self.multiple_results else full_lower(out)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
1316 avals = [t.aval for t in tracers]
-> 1317 out_avals = primitive.abstract_eval(*avals, **params)
1318 out_avals = [out_avals] if not primitive.multiple_results else out_avals
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in _call_tf_abstract_eval(function_flat_tf, args_flat_sig_tf, *_, **__)
233 _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
--> 234 code_gen_optional=True)
235 return tuple(result_avals)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, code_gen_optional)
345 "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
--> 346 raise ValueError(msg) from e
347
UnfilteredStackTrace: ValueError: Error compiling TensorFlow function. call_tf can used in a staged context (under jax.jit, lax.scan, etc.) only with compileable functions. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-16-974d2d918d16> in <module>()
1 x = jnp.float32(1.)
2 jitted_f = jax.jit(cos_tf_sin_jax)
----> 3 output = jitted_f(x)
<ipython-input-7-fb2d1d517a65> in cos_tf_sin_jax(x)
4 # Compute cos with TF and sin with JAX
5 def cos_tf_sin_jax(x):
----> 6 return jax.numpy.sin(jax.experimental.jax2tf.call_tf(cos_tf)(x))
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in make_call(*args_jax)
130 callable_flat_tf=callable_flat_tf,
131 function_flat_tf=function_flat_tf,
--> 132 args_flat_sig_tf=args_flat_sig_tf)
133 return res_treedef.unflatten(res_jax_flat)
134
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in _call_tf_abstract_eval(function_flat_tf, args_flat_sig_tf, *_, **__)
232 # full compilation only to get the abstract avals.
233 _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
--> 234 code_gen_optional=True)
235 return tuple(result_avals)
236
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/call_tf.py in _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, code_gen_optional)
344 "compileable functions. " +
345 "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
--> 346 raise ValueError(msg) from e
347
348 xla_comp = xla_client.XlaComputation(func_tf_hlo)
ValueError: Error compiling TensorFlow function. call_tf can used in a staged context (under jax.jit, lax.scan, etc.) only with compileable functions. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
Hi Jax-developers,
I found that
jax.experimental.jax2tf.call_tf
with JIT is not working with Cloud TPU and Colab notebooks. This small example, derived from the README, fails with a ValueError:
I confirmed that this sample runs smoothly with CPU and GPU runtimes.
I know jax2tf is "experimental", but it is quite convenient, so I want to use it with TPUs.
Here's the stack trace:
Beta Was this translation helpful? Give feedback.
All reactions