How to use PyTensor with jitted functions in JAX (and tracer arrays) #1636
-
I'm working on a project that builds likelihood functions using pyhs3:

```python
import pyhs3
import jax
import jax.numpy as jnp

ws = pyhs3.Workspace.load('tests/test_pdf/rf501_simultaneouspdf.json')
model = ws.model(mode='jax')

@jax.jit
def nll(pars):
    return -2 * model.logpdf('model', x=jnp.array(5.0), **pars)

# all your parameters...
pars = jax.tree.map(jnp.asarray, {'f': 0.2, 'mean': 0.0, 'sigma': 0.3, 'mean2': 0.0, 'sigma2': 0.3})
print("NLL", nll(pars))
```

This fails because pytensor/pytensor/tensor/type.py (line 153 at 1d825dd) calls np.array / np.asarray on the inputs, which does not work with JAX tracer arrays.
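A minimal standalone repro of that failure mode, independent of pyhs3/PyTensor (the function `f` is purely illustrative):

```python
# Inside jax.jit the argument is an abstract Tracer; np.asarray tries to
# turn it into a concrete NumPy array, which JAX forbids during tracing.
import jax
import numpy as np

@jax.jit
def f(x):
    return np.asarray(x)

f(1.0)  # raises jax.errors.TracerArrayConversionError
```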
Inspecting the model:

```python
model.distributions['model']
# Add.0
model._compiled_functions['model']
# <pytensor.compile.function.types.Function object at 0x110913380>
pp(model.distributions['model'])
# '((0.0 + (Clip(f, 0.0, 1.0) * ((1.0 / (Sqrt(6.283185307179586) * Clip(sigma, 0.1, 10.0))) * Exp((-0.5 * (((Clip(x, -8.0, 8.0) - Clip(mean, -8.0, 8.0)) / Clip(sigma, 0.1, 10.0)) ** 2)))))) + ((1 - (0.0 + Clip(f, 0.0, 1.0))) * ((1.0 / (Sqrt(6.283185307179586) * Clip(sigma2, 0.1, 10.0))) * Exp((-0.5 * (((Clip(x, -8.0, 8.0) - Clip(mean2, -3.0, 3.0)) / Clip(sigma2, 0.1, 10.0)) ** 2))))))'
```

The function is compiled like this:

```python
self._compiled_functions[name] = cast(
    Callable[..., npt.NDArray[np.float64]],
    function(
        inputs=inputs,
        outputs=combined_expression,
        mode=compilation_mode,
        on_unused_input="ignore",
    ),
)
```

so I guess it's not clear how to compile something for this use case. /cc @pfackeldey
-
In this case you don't want to compile a PyTensor function, but transpile the underlying graph to JAX. PyMC does this when working with JAX-based samplers here: https://github.com/pymc-devs/pymc/blob/340e403b8813ab5f3699a476cc828cc92c4f9d50/pymc/sampling/jax.py#L122-L133

If you have shared variables you have to decide how to handle them. In PyMC we treat them as constants, since we know users won't be updating them while the JAX samplers are running. But that is context-specific. If you don't have shared variables, then it's really just these three lines:

```python
def get_jaxified_graph(
    inputs: list[TensorVariable] | None = None,
    outputs: list[TensorVariable] | None = None,
) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
    """Compile a PyTensor graph into an optimized JAX function."""
    graph = _replace_shared_variables(outputs) if outputs is not None else None
    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    mode.JAX.optimizer.rewrite(fgraph)
    # We now jaxify the optimized fgraph
    return jax_funcify(fgraph)
```

If you want to define the graph dynamically when JAX calls JIT, you'll have to do some extra work: convert the JAX traced arrays to PyTensor variables, define the PyTensor graph, and then return the JAX graph. I couldn't tell from your description whether you're in that context.
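For reference, `_replace_shared_variables` in that snippet is a small PyMC-specific helper. A rough sketch of the idea (swap every `SharedVariable` input for a constant holding its current value; PyMC's actual implementation adds extra error handling) would be:

```python
# Rough sketch of a _replace_shared_variables helper, not PyMC's exact code:
# replace each SharedVariable in the graph by a constant with its current value.
import pytensor.tensor as pt
from pytensor.compile import SharedVariable
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace


def _replace_shared_variables(graph):
    shared = [v for v in graph_inputs(graph) if isinstance(v, SharedVariable)]
    replacements = {v: pt.constant(v.get_value(borrow=True)) for v in shared}
    return clone_replace(graph, replace=replacements)
```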
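Putting the static case (no shared variables, graph known up front) together, a minimal end-to-end sketch, with a toy Gaussian logpdf as a hypothetical stand-in for `model.distributions['model']`:

```python
# End-to-end sketch of the transpile approach; the Gaussian logpdf graph is
# a hypothetical stand-in for pyhs3's model.distributions['model'].
import jax
import numpy as np
import pytensor.tensor as pt
from pytensor.compile import mode
from pytensor.graph.fg import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify

# Build the symbolic PyTensor graph
x, mu, sigma = pt.scalars("x", "mu", "sigma")
logp = -0.5 * ((x - mu) / sigma) ** 2 - pt.log(sigma) - 0.5 * np.log(2 * np.pi)

# Rewrite the graph for the JAX backend, then transpile instead of compiling
fgraph = FunctionGraph(inputs=[x, mu, sigma], outputs=[logp], clone=True)
mode.JAX.optimizer.rewrite(fgraph)
jax_logp = jax_funcify(fgraph)  # a plain JAX callable returning a tuple of outputs

# The result composes with jit/grad/vmap like any other JAX function
nll = jax.jit(lambda x, mu, sigma: -2 * jax_logp(x, mu, sigma)[0])
print(nll(5.0, 0.0, 1.0))
```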