In this case you don't want to compile a PyTensor function; instead, transpile the underlying graph to JAX. PyMC does this when working with JAX-based samplers here: https://github.com/pymc-devs/pymc/blob/340e403b8813ab5f3699a476cc828cc92c4f9d50/pymc/sampling/jax.py#L122-L133

If you have shared variables, you have to decide how to handle them. In PyMC we treat them as constants, because we know users won't be updating them while the JAX samplers are running. But that choice is context-specific. If you don't have shared variables, then it's really just those 3 lines:

def get_jaxified_graph(
    inputs: list[TensorVariable] | None = None,
    outputs: list[TensorVariable] | None = None,
) ->

Replies: 1 comment, 6 replies (@ricardoV94, @kratsg)

Answer selected by kratsg
Category: Q&A
Labels: None yet
2 participants