crash with a vmap with diagnostic related to jit #14517
Answered
by
soraros
jecampagne
asked this question in
Q&A
-
Hello, I have an example of a schema that creates a problem (the JAX version is the one currently provided by Colab: 0.3.25).

from jax.tree_util import register_pytree_node_class
# Define a base class with tree_flatten/tree_unflatten and other common methods
class GSObject:
    """Base class bundling traced parameters with static settings.

    All keyword arguments other than ``gsparams`` are collected into a single
    dictionary, which `tree_flatten` hands out as the only child of the
    pytree. ``gsparams`` is kept apart and handed out as auxiliary data.
    `tree_unflatten` rebuilds an instance by passing both back to the
    constructor as keyword arguments.
    """

    def __init__(self, *, gsparams=None, **params):
        self._params = params  # Dictionary containing all traced parameters
        self._gsparams = gsparams  # Non-traced static parameters

    @property
    def flux(self):
        """The flux of the profile."""
        return self._params["flux"]

    @property
    def gsparams(self):
        """A `GSParams` object that sets various parameters relevant for speed/accuracy trade-offs."""
        return self._gsparams

    @property
    def params(self):
        """A dictionary containing all parameters of the internal representation of this object."""
        return self._params

    def tree_flatten(self):
        """Flatten the GSObject into children and auxiliary static data.

        Returns:
            A ``(children, aux_data)`` pair: ``children`` is a 1-tuple holding
            the parameter dictionary (the part to be traced by JAX), and
            ``aux_data`` is ``{"gsparams": self.gsparams}`` (static data).
        """
        print("JEC DBG: gsobject tree_flatten")
        # Define the children nodes of the PyTree that need tracing
        children = (self.params,)
        # Define auxiliary static data that doesn't need to be traced
        aux_data = {"gsparams": self.gsparams}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreate an instance of the class from its flattened representation.

        Args:
            aux_data: The static-data dictionary produced by `tree_flatten`.
            children: The 1-tuple produced by `tree_flatten`; its first item
                is the parameter dictionary.

        Returns:
            A new ``cls`` instance built from the parameters and static data,
            all passed as keyword arguments.
        """
        print("JEC DBG: gsobject tree_unflatten")
        return cls(**(children[0]), **aux_data)
# A derived class with the decorator @register_pytree_node_class
@register_pytree_node_class
class A(GSObject):
    """Example profile with ``sigma`` and ``flux`` parameters, registered as a pytree node."""

    def __init__(self, sigma=None, flux=1.0, gsparams=None):
        super().__init__(sigma=sigma, flux=flux, gsparams=gsparams)
        ##self._bug = self.sigma**2

    def __hash__(self):
        return hash(("A", self.sigma, self.flux, self.gsparams))

    def __repr__(self):
        return f"A(sigma={self.sigma!r}, flux={self.flux!r}, gsparams={self.gsparams!r})"

    def __str__(self):
        # The flux is only shown when it differs from the default of 1.0.
        s = f"A(sigma={self.sigma}"
        if self.flux != 1.0:
            s += f", flux={self.flux}"
        s += ")"
        return s

    @property
    def sigma(self):
        """The ``sigma`` entry of the parameter dictionary."""
        return self.params["sigma"]


# Now one example that produces what is expected
def gen_gal(p):
    """Build an `A` from ``p["flux"]`` and ``p["sigma"]`` and return ``flux + 10 * sigma``."""
    aA = A(flux=p["flux"], sigma=p["sigma"])
    return aA.flux + 10 * aA.sigma
print("tst vmap gen_gal")
params = {"flux": jnp.array([1.e5, 2.e5]), "sigma": jnp.array([3., 4.])}
jax.vmap(gen_gal)(params)  # DeviceArray([100030., 200040.], dtype=float64)


# Now an example that crashes
def gen_gal_img(p):
    """Build an `A` from ``p`` and return a zero image of side ``50 * sigma``.

    The image shape is computed from ``sigma``, so mapping this function over
    several ``sigma`` values asks for a differently shaped result per element.
    """
    aA = A(flux=p["flux"], sigma=p["sigma"])
    n = (50 * aA.sigma).astype(int)  # if n is fixed to 50 instead, this is OK
    img = jnp.zeros(shape=(n, n))
    return img
print("tst vmap gen_gal_img")
params = {"flux": jnp.array([1.e5, 2.e5]), "sigma": jnp.array([3., 4.])}
jax.vmap(gen_gal_img)(params)
# leads to a crash
|
Beta Was this translation helpful? Give feedback.
Answered by
soraros
Feb 16, 2023
Replies: 1 comment 3 replies
-
You seem to want to map |
Beta Was this translation helpful? Give feedback.
3 replies
Answer selected by
jecampagne
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You seem to want to map `gen_gal_img` over a sequence of `A`, each of which produces a result with a different shape (`(n := int(50 * a.sigma), n)`), which `vmap` doesn't support.