
UCL/jaxpr-tree


Jaxpr tree


Utility for representing JAX expressions (jaxprs) as a tree.

Builds a tree from a JAX jaxpr in which each node is a primitive call; primitives that nest one or more jaxprs (e.g. jit, while, cond, scan) have children built from those nested jaxprs. It can also compute, for each node, the number of non-structural (e.g. non-JIT) primitive operations in its subtree, which can be used to check whether the size of the traced computation grows with the input size.
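To illustrate the idea, here is a minimal standalone sketch that does not use jaxpr-tree itself. It assumes that nested jaxprs appear in an equation's `params`, either directly (as with the `jaxpr` param of jit) or inside a tuple (as with the `branches` param of cond), and detects them by unwrapping `.jaxpr` and checking for an `eqns` attribute. The helpers `as_jaxpr`, `nested_jaxprs`, `build_tree`, `count_leaves` and `render` are hypothetical names for illustration and are not part of the jaxpr-tree API.

```python
import jax
import jax.numpy as jnp


def as_jaxpr(value):
    """Return the underlying Jaxpr for a Jaxpr or ClosedJaxpr, else None."""
    unwrapped = getattr(value, "jaxpr", value)  # ClosedJaxpr wraps a Jaxpr in `.jaxpr`
    return unwrapped if hasattr(unwrapped, "eqns") else None


def nested_jaxprs(params):
    """Yield (label, jaxpr) for each jaxpr found in an equation's params.

    Assumption: nested jaxprs are stored directly as a param value or inside
    a tuple/list param (e.g. `branches` of cond, labelled `branches[i]`).
    """
    for name, value in params.items():
        if isinstance(value, (tuple, list)):
            for index, item in enumerate(value):
                sub = as_jaxpr(item)
                if sub is not None:
                    yield f"{name}[{index}]", sub
        else:
            sub = as_jaxpr(value)
            if sub is not None:
                yield name, sub


def build_tree(jaxpr):
    """Return one (primitive_name, param_blocks) pair per equation.

    Each param block is a (label, child_nodes) pair built recursively from a
    nested jaxpr; primitives without nested jaxprs get an empty block list.
    """
    return [
        (eqn.primitive.name,
         [(label, build_tree(sub)) for label, sub in nested_jaxprs(eqn.params)])
        for eqn in as_jaxpr(jaxpr).eqns
    ]


def count_leaves(nodes):
    """Count primitives that do not themselves nest a jaxpr."""
    total = 0
    for _, blocks in nodes:
        if blocks:
            total += sum(count_leaves(children) for _, children in blocks)
        else:
            total += 1
    return total


def render(nodes, depth=0):
    """Return an indented text rendering of the tree, one line per node."""
    lines = []
    for name, blocks in nodes:
        lines.append("    " * depth + name)
        for label, children in blocks:
            lines.append("    " * (depth + 1) + f"(param block)  {label}")
            lines.extend(render(children, depth + 2))
    return lines


# Usage on the same functions as the example usage in this README.
# Note: depending on the JAX version, the jit primitive may be named "pjit".
@jax.jit
def inner(x):
    return jnp.cos(x)


def outer(x):
    return jax.lax.cond(x < 0, lambda: 2 * inner(x) + 1.0, lambda: -x)


tree = build_tree(jax.make_jaxpr(outer)(0.5))
print("\n".join(render(tree)))
print("leaves =", count_leaves(tree))
```

Representing nodes as plain tuples keeps the sketch short; a real tool would more likely use a node class that caches the subtree counts.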

Warning

This tool was created with the help of large language model coding assistants as a quick prototype to aid in diagnosing what was causing the JIT compilation time to grow with input size in a particular project. While some basic verification of the code has been performed, there are no guarantees that it will give sensible results in all cases.

Getting started

Prerequisites

jaxpr-tree requires Python 3.12+.

Installation

We recommend installing in a project-specific virtual environment created using an environment management tool such as uv. To install the latest development version of jaxpr-tree using uv in the currently active environment, run

uv pip install git+https://github.com/UCL/jaxpr-tree.git

Alternatively, create a local clone of the repository with

git clone https://github.com/UCL/jaxpr-tree.git

and then, from the root of the cloned repository, install in editable mode by running

uv pip install -e .

Example usage

import jax
import jaxpr_tree

@jax.jit
def inner(x):
    return jax.numpy.cos(x)

def outer(x):
    return jax.lax.cond(
        x < 0,
        lambda: 2 * inner(x) + 1.0,
        lambda: -x
    )

x = 0.5
jaxpr = jax.make_jaxpr(outer)(x)
tree = jaxpr_tree.jaxpr_to_tree(jaxpr)
tree.compute_subtree_leaf_counts()
print(tree)

which outputs

(root)  leaves=6
    ├── lt
    ├── convert_element_type
    └── cond  leaves=4
        ├── (param block)  branches[0]
        │   └── neg
        └── (param block)  branches[1]  leaves=3
            ├── jit
            │   └── (param block)  jaxpr
            │       └── cos
            ├── mul
            └── add
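Subtree leaf counts are most informative when compared across input sizes. As a hedged, standalone illustration (it does not call jaxpr-tree, and the function names are hypothetical), the sketch below counts non-structural primitives for two versions of a running sum: one using a Python `for` loop, which JAX unrolls at trace time, and one using `jax.lax.fori_loop`, whose body is traced once regardless of the array length.

```python
import jax
import jax.numpy as jnp


def count_primitives(jaxpr):
    """Count equations with no nested jaxprs, recursing into those that have them."""
    jaxpr = getattr(jaxpr, "jaxpr", jaxpr)  # unwrap a ClosedJaxpr if needed
    total = 0
    for eqn in jaxpr.eqns:
        subs = []
        for value in eqn.params.values():
            items = value if isinstance(value, (tuple, list)) else (value,)
            for item in items:
                sub = getattr(item, "jaxpr", item)
                if hasattr(sub, "eqns"):
                    subs.append(sub)
        total += sum(count_primitives(sub) for sub in subs) if subs else 1
    return total


def cumulative_unrolled(x):
    # The Python loop runs at trace time, so one set of ops is traced per element.
    total = jnp.zeros(())
    for i in range(x.shape[0]):
        total = total + x[i]
    return total


def cumulative_fori(x):
    # fori_loop traces the loop body once and stores it as a nested jaxpr.
    return jax.lax.fori_loop(0, x.shape[0], lambda i, acc: acc + x[i], jnp.zeros(()))


for n in (4, 8, 16):
    x = jnp.ones(n)
    unrolled = count_primitives(jax.make_jaxpr(cumulative_unrolled)(x))
    looped = count_primitives(jax.make_jaxpr(cumulative_fori)(x))
    print(f"n={n:3d}  unrolled={unrolled:4d}  fori_loop={looped:4d}")
```

A count that grows with `n`, as expected for the unrolled version, indicates that the traced computation (and often the compilation time) scales with the input size; a count that stays constant, as expected for the `fori_loop` version, suggests it does not.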

Running tests

From the root of the repository, tests can be run across all compatible Python versions in isolated environments using tox by running

tox

To run tests manually in a Python environment with pytest installed, run

pytest tests

again from the root of the repository.

Building documentation

The MkDocs HTML documentation can be built locally by running

tox -e docs

from the root of the repository. The built documentation will be written to the site directory.

Alternatively, to build and preview the documentation locally, run the following in a Python environment with the optional docs dependencies installed

mkdocs serve
