Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
# test with oldest supported Python version only (for slow tests)
python-version: ["3.10"]

demo:
example:
- simple
- cfd

Expand Down Expand Up @@ -46,7 +46,7 @@ jobs:
uv sync --extra dev --frozen

- name: Run example
working-directory: demo/${{matrix.demo}}
working-directory: examples/${{matrix.example}}
run: |
uv pip install jupyter
uv run --no-sync jupyter nbconvert --to notebook --execute demo.ipynb
28 changes: 15 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of [JAX](https://github.com/jax-ml/jax) programs, with full support for function transformations like JIT, `grad`, and more.

[Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) |
[Explore the demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) |
[Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) |
[Report an issue](https://github.com/pasteurlabs/tesseract-jax/issues) |
[Talk to the community](https://si-tesseract.discourse.group/) |
[Contribute](CONTRIBUTING.md)
Expand All @@ -31,7 +31,7 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
2. Build an example Tesseract:

```bash
$ tesseract build demo/simple/vectoradd_jax
$ tesseract build examples/simple/vectoradd_jax
```

3. Use it as part of a JAX program via the JAX-native `apply_tesseract` function:
Expand All @@ -44,26 +44,28 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")
t.serve()

# Run it with JAX
x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
x = jnp.ones((1000,))
y = jnp.ones((1000,))

def vector_add(x, y):
return apply_tesseract(t, x, y)
def vector_sum(x, y):
res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
return res["vector_add"]["result"].sum()

vector_add(x, y) # success!
vector_sum(x, y) # success!

# You can also use it with JAX transformations like JIT and grad
vector_add_jit = jax.jit(vector_add)
vector_add_jit(x, y)
# You can also use it with JAX transformations like JIT and grad
vector_sum_jit = jax.jit(vector_sum)
vector_sum_jit(x, y)

vector_add_grad = jax.grad(vector_add)
vector_add_grad(x, y)
vector_sum_grad = jax.grad(vector_sum)
vector_sum_grad(x, y)
```

> [!TIP]
> Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX.
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.

## Sharp edges

Expand Down
5 changes: 0 additions & 5 deletions demo/cfd/Readme.md

This file was deleted.

Empty file removed demo/simple/README.md
Empty file.
28 changes: 23 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

import re
import shutil
from pathlib import Path

from tesseract_jax import __version__

project = "tesseract-jax"
copyright = "2025, The tesseract-jax Team @ Pasteur Labs"
author = "The tesseract-jax Team @ Pasteur Labs"
project = "Tesseract-JAX"
copyright = "2025, Pasteur Labs"
author = "The Tesseract-JAX Team @ Pasteur Labs + OSS contributors"

# The short X.Y version
parsed_version = re.match(r"(\d+\.\d+\.\d+)", __version__)
Expand All @@ -28,12 +30,16 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"myst_parser",
"myst_nb", # This is myst-parser + jupyter notebook support
"sphinx.ext.intersphinx",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx_autodoc_typehints",
# Copy button for code blocks
"sphinx_copybutton",
# OpenGraph metadata for social media sharing
"sphinxext.opengraph",
]

myst_enable_extensions = [
Expand Down Expand Up @@ -62,6 +68,18 @@
html_theme_options = {
"light_logo": "logo-light.png",
"dark_logo": "logo-dark.png",
"sidebar_hide_name": False,
"sidebar_hide_name": True,
}
html_css_files = ["custom.css"]


# -- Handle Jupyter notebooks ------------------------------------------------

# Do not execute notebooks during build (just take existing output)
nb_execution_mode = "off"

# Copy example notebooks to demo_notebooks folder on every build
for example_notebook in Path("../examples").glob("*/demo.ipynb"):
# Copy the example notebook to the docs folder
dest = (Path("demo_notebooks") / example_notebook.parent.name).with_suffix(".ipynb")
shutil.copyfile(example_notebook, dest)
2 changes: 1 addition & 1 deletion docs/content/api.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# API documentation
# API reference

```{eval-rst}
.. automodule:: tesseract_jax
Expand Down
79 changes: 79 additions & 0 deletions docs/content/get-started.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Get started

`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.

## Quick start

```{note}
Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+).
```

```{seealso}
For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html).
```

1. Install Tesseract-JAX:

```bash
$ pip install tesseract-jax
```

2. Build an example Tesseract:

```bash
$ tesseract build examples/simple/vectoradd_jax
```

3. Use it as part of a JAX program:

```python
import jax
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")
t.serve()

# Run it with JAX
x = jnp.ones((1000,))
y = jnp.ones((1000,))

def vector_sum(x, y):
res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
return res["vector_add"]["result"].sum()

vector_sum(x, y) # success!

# You can also use it with JAX transformations like JIT and grad
vector_sum_jit = jax.jit(vector_sum)
vector_sum_jit(x, y)

vector_sum_grad = jax.grad(vector_sum)
vector_sum_grad(x, y)
```

```{tip}
Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for ways to use Tesseract-JAX.
```

## Sharp edges

- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.

```python
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

tess = Tesseract.from_image("vectoradd")
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
```
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.

```{tip}
When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
```
1 change: 1 addition & 0 deletions docs/demo_notebooks/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.ipynb
94 changes: 20 additions & 74 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,93 +1,39 @@
# tesseract-jax
# Tesseract-JAX

`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.

## Quick start

```{note}
Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+).
```

```{seealso}
For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html).
```{include} content/get-started.md
:start-line: 2
```

1. Install Tesseract-JAX:

```bash
$ pip install tesseract-jax
```

2. Build an example Tesseract:

```bash
$ tesseract build demo/simple/vectoradd_jax
```

3. Use it as part of a JAX program:

```python
import jax
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")

# Run it with JAX
x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
## License

def vector_add(x, y):
return apply_tesseract(t, x, y)
Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).

vector_add(x, y) # success!
Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.

# You can also use it with JAX transformations like JIT and grad
vector_add_jit = jax.jit(vector_add)
vector_add_jit(x, y)

vector_add_grad = jax.grad(vector_add)
vector_add_grad(x, y)
```
```{toctree}
:caption: Usage
:maxdepth: 2
:hidden:

```{tip}
Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX.
content/get-started
content/api
```

## Sharp edges

- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.

```python
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

tess = Tesseract.from_image("vectoradd")
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
```
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
```{toctree}
:caption: Examples
:maxdepth: 2
:hidden:

```{tip}
When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
demo_notebooks/simple.ipynb
demo_notebooks/cfd.ipynb
```

```{toctree}
:caption: Contents
:caption: See also
:maxdepth: 2
:hidden:

content/api
Tesseract Core documentation <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/>
Tesseract Core docs <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/>
Tesseract User Forums <https://si-tesseract.discourse.group/>
```

## License

Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).

Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.
Binary file modified docs/static/logo-dark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/static/logo-light.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Tesseract-JAX examples

This directory contains example Tesseract configurations, notebooks. and scripts demonstrating how to use Tesseract-JAX in various contexts. Each example is self-contained and can be run independently.

## Examples

- [Simple](simple/demo.ipynb): A basic example of using Tesseract-JAX with a simple vector addition task. It demonstrates how to build a Tesseract and execute it within JAX.
- [CFD](cfd/demo.ipynb): A more complex example demonstrating how to use Tesseract-JAX to differentiate through a computational fluid dynamics (CFD) simulation Tesseract.
Loading