No error raised by jitting class methods with backend agnostic architecture #19143
-
Hi, I am writing code using a backend-agnostic architecture so that I can run my existing code with both NumPy and JAX. The code looks something like this:

import numpy as np
import jax.numpy as jnp
from jax import jit
import jax
class NumPyBackend:
    """Backend that evaluates math with NumPy and does no JIT compilation."""

    def sin(self, x):
        """Return ``np.sin(x)``."""
        return np.sin(x)

    def cos(self, x):
        """Return ``np.cos(x)``."""
        return np.cos(x)

    def jit(self, f):
        """Return ``f`` unchanged, so decorated functions stay plain Python."""
        return f
class JaxBackend:
    """Backend that evaluates math with ``jax.numpy`` and compiles via ``jax.jit``."""

    def sin(self, x):
        """Return ``jnp.sin(x)``."""
        return jnp.sin(x)

    def cos(self, x):
        """Return ``jnp.cos(x)``."""
        return jnp.cos(x)

    def jit(self, f):
        """Return ``f`` wrapped with ``jax.jit``."""
        return jit(f)
# Backend bound while CustomClass is being defined; the ``@BACKEND.jit``
# decorator on ``CustomClass.calc`` reads this binding.
BACKEND = NumPyBackend()
class CustomClass:
    """Hold a value ``x`` and check sin(x)**2 + cos(x)**2 through ``BACKEND``.

    The class also provides the flatten/unflatten pair used to register it
    as a JAX pytree node.
    """

    def __init__(self, x):
        """Store ``x`` as the only piece of instance state."""
        self.x = x

    # NOTE: the decorator expression is evaluated once, while this class body
    # executes, so ``calc`` is wrapped by whatever object the module-level
    # ``BACKEND`` names at that moment. Rebinding ``BACKEND`` afterwards does
    # not re-wrap ``calc``; only the ``BACKEND.sin`` / ``BACKEND.cos`` lookups
    # inside the body see the new binding, because they happen at call time.
    @BACKEND.jit
    def calc(self):
        """Return sin(x)**2 + cos(x)**2 when it is close to 1.

        Falls off the end (returning ``None``) when the check fails.
        """
        if jnp.isclose(BACKEND.sin(self.x)**2 + BACKEND.cos(self.x)**2, 1):
            return BACKEND.sin(self.x)**2 + BACKEND.cos(self.x)**2

    def _tree_flatten(self):
        """Return ``(children, aux_data)`` for pytree registration."""
        children = (self.x,)  # arrays / dynamic values
        aux_data = {}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of ``_tree_flatten``."""
        return cls(*children, **aux_data)
# Register CustomClass as a pytree node using its flatten/unflatten methods.
from jax import tree_util
tree_util.register_pytree_node(CustomClass,
CustomClass._tree_flatten,
CustomClass._tree_unflatten)
# Run 1: Python float input with the NumPy backend selected.
a = 1.0
BACKEND = NumPyBackend()
a_np = CustomClass(a)
print(a_np.calc())
# 1.0 (No error)
# Run 2: JAX array input with the JAX backend selected. Rebinding BACKEND
# here does not re-apply the @BACKEND.jit decorator on CustomClass.calc;
# that decorator already ran when the class was defined.
a = jnp.array(a)
BACKEND = JaxBackend()
a_jax = CustomClass(a)
print(a_jax.calc())
# 1.0 (No error)
@jit
def a_ref(x):
    """Reference version of the check, wrapped directly with ``jax.jit``.

    Returns sin(x)**2 + cos(x)**2 when the ``BACKEND``-computed value is
    close to 1; otherwise falls off the end and returns ``None``.

    Under ``jit`` the ``if`` below tries to convert a traced boolean to a
    Python ``bool``, which raises ``TracerBoolConversionError`` when the
    function is called (see the transcript that follows the call).
    """
    if jnp.isclose(BACKEND.sin(x)**2 + BACKEND.cos(x)**2, 1):
        return jnp.sin(x)**2 + jnp.cos(x)**2
# Calling the jitted reference function fails with the error shown below.
print(a_ref(a))
# TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
# The error occurred while tracing the function a_ref at /tmp/ipykernel_14657/129821386.py:62 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

I am wondering why `a_jax.calc()` does not raise an error. Note that the following code will raise an error:

import jax
class CustomClass:
    """Hold a value ``x`` and check sin(x)**2 + cos(x)**2 under ``jax.jit``.

    The class also provides the flatten/unflatten pair used to register it
    as a JAX pytree node.
    """

    def __init__(self, x):
        """Store ``x`` as the only piece of instance state."""
        self.x = x

    @jax.jit
    def norm(self):
        """Return sin(x)**2 + cos(x)**2 when it is close to 1.

        Falls off the end (returning ``None``) when the check fails. Because
        the method is always jitted, the ``if`` on a traced value raises
        ``TracerBoolConversionError`` when called on a registered instance
        (see the transcript that follows the call).
        """
        if jnp.isclose(jnp.sin(self.x)**2 + jnp.cos(self.x)**2, 1):
            return jnp.sin(self.x)**2 + jnp.cos(self.x)**2

    def _tree_flatten(self):
        """Return ``(children, aux_data)`` for pytree registration."""
        children = (self.x,)  # arrays / dynamic values
        aux_data = {}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of ``_tree_flatten``."""
        return cls(*children, **aux_data)
# Register the redefined CustomClass as a pytree node, then call the
# jitted method; the call fails with the error shown below.
from jax import tree_util
tree_util.register_pytree_node(CustomClass,
CustomClass._tree_flatten,
CustomClass._tree_unflatten)
a_ref2 = CustomClass(a)
a_ref2.norm()
# TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
# The error occurred while tracing the function norm at /tmp/ipykernel_20399/589877195.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument self[<flat index 0>].
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

OS: Windows WSL2 (Ubuntu)
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 1 reply
-
Hi - the reason that |
Beta Was this translation helpful? Give feedback.
-
Side note -- registering classes as custom PyTrees can be a fairly finickity business in JAX. You might like Equinox, which offers an …
Beta Was this translation helpful? Give feedback.
Hi - the reason that `a_jax.calc()` does not error is that the `calc` method of `a_jax` is not JIT-compiled. Class method decorators are evaluated at the time the class is defined, not at the time the class is instantiated, and when the class is defined, the global variable `BACKEND` is set to `NumPyBackend()`. The fact that this global variable is modified later has no effect on the `@BACKEND.jit` decorator used previously in the class definition.