Is there a best practice for raising warnings? I want to raise a warning when a runtime value exceeds a certain limit and print that value. Is this possible in general? The best I could come up with so far is the following, but it displays the tracer and not its value (the same happens without jitting, but I guess that is because …)

```python
import jax.numpy as jnp
import jax
import warnings

def warn(x):
    warnings.warn(f"variable x is now above 1: x = {x}")

@jax.jit
def func(x):
    x = x + 1
    jax.lax.cond(x > 1, warn, lambda x: None, x)

x = jnp.array(1., dtype=float)
```

```
>>> func(x)
/tmp/ipykernel_27944/1070628015.py:2: UserWarning: variable x is now above 1: x = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/1)>
  warnings.warn(f"variable x is now above 1: x = {x}")
```

Edit: working solution in the comments below the accepted answer.
@Qottmann there's an experimental callback module which allows signaling events back to the Python runtime: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html
You can use some of the functions in this module to send runtime values on the device back to the Python runtime, e.g. see `id_tap`.
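Note that `jax.experimental.host_callback` has since been deprecated in newer JAX releases in favour of `jax.debug.callback` (and `jax.experimental.io_callback`). Below is a minimal sketch of the same idea with `jax.debug.callback`, assuming a JAX version that provides both it and `jax.effects_barrier`. The check runs on the host, and the callback receives the concrete runtime value rather than a tracer, because it is invoked when the compiled function executes, not when it is traced. The names `warn_if_above`, `func` and `LIMIT` are illustrative only.

```python
import warnings

import jax
import jax.numpy as jnp

LIMIT = 1.0  # illustrative threshold (hypothetical constant, not from the thread)


def warn_if_above(x):
    # Runs on the host with the concrete value of x, not a tracer.
    if x > LIMIT:
        warnings.warn(f"variable x is now above {LIMIT}: x = {float(x)}")


@jax.jit
def func(x):
    x = x + 1
    # Hand the runtime value to a host-side callback; its return value is ignored.
    jax.debug.callback(warn_if_above, x)
    return x


x = jnp.array(1.0)
func(x)
# Callbacks may run asynchronously; wait for pending ones to finish.
jax.effects_barrier()
```

Because the comparison happens inside the callback, every call of `func` makes a host round trip, even when no warning is issued.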
For completeness, here is a working solution. Feedback for improvements encouraged!

```python
import warnings
import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def warn(x, transforms):  # id_tap calls the tap function as tap_func(arg, transforms)
    if x > 1:
        warnings.warn(f"variable x is now above 1: x = {x}")

@jax.jit
def func(x):
    x = x + 1
    host_callback.id_tap(warn, x)  # runs warn on the host with the concrete value of x

x = jnp.array(1., dtype=float)
```

```
>>> func(x)
/tmp/ipykernel_1141/4253178298.py:8: UserWarning: variable x is now above 1: x = 2.0
  warnings.warn(f"variable x is now above 1: x = {x}")
```
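To keep the `lax.cond` structure from the original question (compare on the device and contact the host only when the limit is exceeded), the host call can be placed inside the branch instead. The tracer showed up in the question because `lax.cond` traces its branch functions to build the conditional, so Python code in a branch runs at trace time and only ever sees a tracer. In the sketch below the branch does not format `x` itself; it passes `x` to a host callback that formats the concrete value. This assumes `jax.debug.callback` is available and supported inside `lax.cond`, and `_warn` / `_warn_branch` are hypothetical helper names.

```python
import warnings

import jax
import jax.numpy as jnp


def _warn(x):
    # Host side: x arrives as a concrete array, so its value can be printed.
    warnings.warn(f"variable x is now above 1: x = {float(x)}")


def _warn_branch(x):
    # Device side: traced once, but the callback only fires when this branch is taken.
    jax.debug.callback(_warn, x)


@jax.jit
def func(x):
    x = x + 1
    # Both branches return None, i.e. the same (empty) output structure.
    jax.lax.cond(x > 1, _warn_branch, lambda x: None, x)
    return x


func(jnp.array(1.0))
jax.effects_barrier()  # wait for any pending callbacks
```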