Utility for representing JAX expressions (jaxprs) as a tree.
Builds a tree from a JAX jaxpr where each node is a primitive call; primitives that
nest a jaxpr (e.g. jit, while, cond, scan) have children from their nested jaxprs.
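To make the tree structure concrete, here is a minimal sketch of how such a tree could be built from a jaxpr. It is a hypothetical re-implementation and not jaxpr-tree's own code. The `Node` class and the `jaxpr_to_node` and `_nested_jaxpr` helpers are illustrative names. The sketch also assumes that nested jaxprs in an equation's `params` can be recognised by duck typing: they are objects with an `eqns` attribute, possibly wrapped in a closed jaxpr's `.jaxpr`.

```python
from __future__ import annotations

from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


@dataclass
class Node:
    """Hypothetical tree node: a label plus child nodes."""

    name: str
    children: list[Node] = field(default_factory=list)


def _nested_jaxpr(value):
    """Return the Jaxpr held by `value` (a Jaxpr or ClosedJaxpr), else None.

    Assumption of this sketch: duck typing on `eqns` is enough to spot jaxprs.
    """
    inner = getattr(value, "jaxpr", value)  # unwrap a ClosedJaxpr
    return inner if hasattr(inner, "eqns") else None


def jaxpr_to_node(jaxpr, label: str = "(root)") -> Node:
    """Hypothetical helper: one child per equation, recursing into nested jaxprs."""
    node = Node(label)
    for eqn in jaxpr.eqns:
        child = Node(eqn.primitive.name)
        for param, value in eqn.params.items():
            # Some params hold a sequence of jaxprs (e.g. cond's `branches`).
            is_seq = isinstance(value, (tuple, list))
            for i, item in enumerate(value if is_seq else (value,)):
                nested = _nested_jaxpr(item)
                if nested is not None:
                    suffix = f"[{i}]" if is_seq else ""
                    child.children.append(
                        jaxpr_to_node(nested, f"(param block) {param}{suffix}")
                    )
        node.children.append(child)
    return node


def outer(x):
    return jax.lax.cond(x < 0, lambda: jnp.sin(x), lambda: -x)


root = jaxpr_to_node(jax.make_jaxpr(outer)(0.5).jaxpr)
print([child.name for child in root.children])
```

Each primitive becomes a child of the jaxpr it appears in. Each parameter holding a jaxpr becomes a "(param block)" node, and that node's children are the primitives of the nested jaxpr.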
Supports computing the number of non-structural (e.g. non-JIT) primitive operations
in the subtree of each node. Comparing these counts between traces with different
input sizes shows whether the size of the traced computation changes with the input size.
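Structural primitives only ever appear as internal nodes, so one way to read the count of non-structural operations is as the number of leaves below each node. Below is a minimal sketch of that post-order count on a hand-built tree. It is a hypothetical re-implementation: `Node` and `compute_leaf_counts` are illustrative names, not jaxpr-tree's API. It assumes that every childless node counts as exactly one operation.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical tree node; `leaves` is filled in by compute_leaf_counts."""

    name: str
    children: list[Node] = field(default_factory=list)
    leaves: int = 0


def compute_leaf_counts(node: Node) -> int:
    """Post-order traversal: a childless node counts as one operation,
    an internal node sums the counts of all of its subtrees."""
    if not node.children:
        node.leaves = 1
    else:
        node.leaves = sum(compute_leaf_counts(child) for child in node.children)
    return node.leaves


# Hand-built tree mirroring the structure of the example output in this README.
branch_true = Node("(param block) branches[1]", [
    Node("jit", [Node("(param block) jaxpr", [Node("cos")])]),
    Node("mul"),
    Node("add"),
])
cond = Node("cond", [Node("(param block) branches[0]", [Node("neg")]), branch_true])
root = Node("(root)", [Node("lt"), Node("convert_element_type"), cond])

compute_leaf_counts(root)
print(root.leaves, cond.leaves, branch_true.leaves)  # 6 4 3
```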
Warning
This tool was created with the help of large language model coding assistants as a quick prototype to help diagnose what was causing the JIT compilation time to grow with input size in a particular project. While some basic verification of the code has been performed, there are no guarantees that it will give sensible results in all cases.
jaxpr-tree requires Python 3.12+.
We recommend installing in a project-specific virtual environment created using
an environment management tool such as
uv. To install the latest
development version of jaxpr-tree using uv in the currently active
environment, run

```sh
uv pip install git+https://github.com/UCL/jaxpr-tree.git
```

Alternatively, create a local clone of the repository with

```sh
git clone https://github.com/UCL/jaxpr-tree.git
```

and then, from the root of the clone, install in editable mode by running
```sh
uv pip install -e .
```

For example, running

```python
import jax
import jaxpr_tree


@jax.jit
def inner(x):
    return jax.numpy.cos(x)


def outer(x):
    return jax.lax.cond(
        x < 0,
        lambda: 2 * inner(x) + 1.0,
        lambda: -x
    )


x = 0.5
jaxpr = jax.make_jaxpr(outer)(x)
tree = jaxpr_tree.jaxpr_to_tree(jaxpr)
tree.compute_subtree_leaf_counts()
print(tree)
```

outputs

```
(root) leaves=6
├── lt
├── convert_element_type
└── cond leaves=4
    ├── (param block) branches[0]
    │   └── neg
    └── (param block) branches[1] leaves=3
        ├── jit
        │   └── (param block) jaxpr
        │       └── cos
        ├── mul
        └── add
```
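The leaf counts above come from a single trace. Their intended use is comparing traces at different input sizes. As a rough, self-contained illustration of that workflow, the sketch below counts primitives directly from `jax.make_jaxpr` output rather than through jaxpr-tree, skipping any primitive that contains a nested jaxpr. It compares two hypothetical functions. The first sums with a Python `for` loop, which is unrolled at trace time. The second uses `jax.lax.fori_loop`, whose body is traced once. The helper `count_leaf_primitives` and its duck-typed detection of nested jaxprs are assumptions of this sketch.

```python
import jax
import jax.numpy as jnp


def count_leaf_primitives(jaxpr) -> int:
    """Hypothetical helper: count equations without nested jaxprs,
    recursing into the nested jaxprs of those that have them."""
    total = 0
    for eqn in jaxpr.eqns:
        nested = []
        for value in eqn.params.values():
            items = value if isinstance(value, (tuple, list)) else (value,)
            for item in items:
                inner = getattr(item, "jaxpr", item)  # unwrap a ClosedJaxpr
                if hasattr(inner, "eqns"):
                    nested.append(inner)
        if nested:
            total += sum(count_leaf_primitives(j) for j in nested)
        else:
            total += 1
    return total


def unrolled_sum(x):
    # A Python loop is unrolled while tracing: one copy of the body per element.
    total = 0.0
    for i in range(x.shape[0]):
        total = total + jnp.sin(x[i])
    return total


def rolled_sum(x):
    # fori_loop traces its body once, whatever the number of iterations.
    def body(i, total):
        return total + jnp.sin(x[i])

    return jax.lax.fori_loop(0, x.shape[0], body, jnp.zeros((), x.dtype))


for n in (2, 4, 8):
    x = jnp.ones(n)
    print(
        n,
        count_leaf_primitives(jax.make_jaxpr(unrolled_sum)(x).jaxpr),
        count_leaf_primitives(jax.make_jaxpr(rolled_sum)(x).jaxpr),
    )
```

The first count should grow with `n` while the second stays the same. A count that grows with input size in this way is the kind of behaviour the subtree counts in the tree help to locate.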
Tests can be run across all compatible Python versions in isolated environments
using tox by running

```sh
tox
```

from the root of the repository. To run tests manually in a Python environment with pytest installed, run

```sh
pytest tests
```

again from the root of the repository.
The MkDocs HTML documentation can be built locally by running

```sh
tox -e docs
```

from the root of the repository. The built documentation will be written to the
`site` directory.

Alternatively, to build and preview the documentation locally in a Python
environment with the optional docs dependencies installed, run

```sh
mkdocs serve
```