Avoid re-computations of lazy properties after pytree flattening #10486
-
Hi all! As part of a personal project I have to deal a lot with parameters of multivariate Gaussian distributions, and, taking inspiration from NumPyro, I want to use some kind of lazy evaluation with a tool like this:
which I can use like this for example:
The lazy evaluation is practical, but I can't find a way to satisfy both of the following points at the same time:
But maybe under JIT none of this matters? I am not clear about what kind of optimizations are done at compile time with respect to unused computations. For example, maybe it doesn't matter that "cov" is computed if it is not used, because compilation will discard that computation anyway? Thanks in advance!
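A NumPyro-style lazy property can be sketched as a non-data descriptor. It computes the value on first access and then stores it in the instance `__dict__`, so later lookups never reach the descriptor. The sketch below assumes the decorator behaves this way. `NaiveCovParams` and `N_COV_CALLS` are hypothetical names used only for illustration.

The sketch shows where the recomputation comes from. If `tree_flatten` emits only `scale_tril`, then `tree_unflatten` builds a fresh object with an empty cache, so `cov` is computed again.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class lazy_property:
    """Non-data descriptor: computes the value on first access and caches it
    in the instance __dict__, which takes precedence on later lookups."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        # Shadow the descriptor with a plain instance attribute.
        setattr(instance, self.wrapped.__name__, value)
        return value


# Hypothetical counter: how many times `cov` has actually been computed.
N_COV_CALLS = 0


@register_pytree_node_class
class NaiveCovParams:
    """Hypothetical container whose pytree children are only `scale_tril`."""

    def __init__(self, scale_tril):
        self.scale_tril = scale_tril

    @lazy_property
    def cov(self):
        global N_COV_CALLS
        N_COV_CALLS += 1
        return self.scale_tril @ jnp.swapaxes(self.scale_tril, -1, -2)

    def tree_flatten(self):
        return (self.scale_tril,), None  # the cached `cov` is not a child

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


p = NaiveCovParams(jnp.eye(2))
p.cov  # computed and cached (N_COV_CALLS == 1)
p.cov  # served from the instance __dict__, no recomputation
leaves, treedef = jax.tree_util.tree_flatten(p)
q = jax.tree_util.tree_unflatten(treedef, leaves)
q.cov  # new object with an empty cache: computed again (N_COV_CALLS == 2)
```

Whether the cache survives depends entirely on what `tree_flatten` returns as children. Anything stored only in the instance `__dict__` is dropped by the round trip.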
-
I think unused computation can trivially be optimized by JIT. But I think there is some workaround as well:

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from numpyro.distributions.util import lazy_property  # caches the value in the instance __dict__


@register_pytree_node_class
class CovParams:
    def __init__(self, scale_tril=None, cov=None):
        if cov is not None:
            # Keep the given covariance and derive the Cholesky factor from it.
            self.cov = cov
            self.scale_tril = jnp.linalg.cholesky(cov)
        elif scale_tril is not None:
            # `cov` is computed lazily on first access.
            self.scale_tril = scale_tril
        else:
            raise ValueError("One of `cov`, `scale_tril` must be specified.")

    @lazy_property
    def cov(self):
        return jnp.matmul(self.scale_tril, jnp.swapaxes(self.scale_tril, -1, -2))

    def tree_flatten(self):
        # `cov` is in the instance __dict__ only if it was passed in or has
        # already been computed; otherwise use None (an empty subtree).
        cov = vars(self).get("cov")
        children = (self.scale_tril, cov)
        return children, None  # no static aux data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        scale_tril, cov = children
        obj = cls(scale_tril=scale_tril)
        if cov is not None:
            # Restore the cached value so `obj.cov` is not recomputed.
            obj.cov = cov
        return obj
```
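One way to check the first claim in the reply (that JIT drops unused computation) is to compile a small function and look at the optimized HLO. The sketch below is a quick, backend-dependent check. It assumes that XLA's dead-code elimination runs during compilation and that a surviving matmul would typically be printed as a `dot(` instruction. `f` is a hypothetical stand-in, not code from the thread.

```python
import jax
import jax.numpy as jnp


def f(scale_tril):
    # Hypothetical example: `cov` is computed but never reaches the output.
    cov = scale_tril @ jnp.swapaxes(scale_tril, -1, -2)
    return jnp.sum(scale_tril)


x = jnp.eye(3)
compiled = jax.jit(f).lower(x).compile()
hlo = compiled.as_text()  # HLO text after XLA's optimization passes
# A surviving matmul would typically show up as a `dot(` instruction.
has_dot = " dot(" in hlo
print("matmul kept:", has_dot)
```

This only covers work whose result never reaches an output. If a jitted function returns a `CovParams` whose `cov` has already been computed, `cov` is an output leaf, so it is computed and materialized.

Also, a `CovParams` with a cached `cov` and one without it flatten to different tree structures (a leaf versus `None`). As a result, `jit` traces and compiles them separately.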