Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
# test with oldest supported Python version only (for slow tests)
python-version: ["3.10"]

demo:
example:
- simple
- cfd

Expand Down Expand Up @@ -46,7 +46,7 @@ jobs:
uv sync --extra dev --frozen

- name: Run example
working-directory: demo/${{matrix.demo}}
working-directory: examples/${{matrix.example}}
run: |
uv pip install jupyter
uv run --no-sync jupyter nbconvert --to notebook --execute demo.ipynb
28 changes: 15 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of [JAX](https://github.com/jax-ml/jax) programs, with full support for function transformations like JIT, `grad`, and more.

[Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) |
[Explore the demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) |
[Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) |
[Report an issue](https://github.com/pasteurlabs/tesseract-jax/issues) |
[Talk to the community](https://si-tesseract.discourse.group/) |
[Contribute](CONTRIBUTING.md)
Expand All @@ -31,7 +31,7 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
2. Build an example Tesseract:

```bash
$ tesseract build demo/simple/vectoradd_jax
$ tesseract build examples/simple/vectoradd_jax
```

3. Use it as part of a JAX program via the JAX-native `apply_tesseract` function:
Expand All @@ -44,26 +44,28 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")
t.serve()

# Run it with JAX
x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
x = jnp.ones((1000,))
y = jnp.ones((1000,))

def vector_add(x, y):
return apply_tesseract(t, x, y)
def vector_sum(x, y):
res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
return res["vector_add"]["result"].sum()

vector_add(x, y) # success!
vector_sum(x, y) # success!

# You can also use it with JAX transformations like JIT and grad
vector_add_jit = jax.jit(vector_add)
vector_add_jit(x, y)
# You can also use it with JAX transformations like JIT and grad
vector_sum_jit = jax.jit(vector_sum)
vector_sum_jit(x, y)

vector_add_grad = jax.grad(vector_add)
vector_add_grad(x, y)
vector_sum_grad = jax.grad(vector_sum)
vector_sum_grad(x, y)
```

> [!TIP]
> Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX.
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.

## Sharp edges

Expand Down
5 changes: 0 additions & 5 deletions demo/cfd/Readme.md

This file was deleted.

Empty file removed demo/simple/README.md
Empty file.
28 changes: 23 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

import re
import shutil
from pathlib import Path

from tesseract_jax import __version__

project = "tesseract-jax"
copyright = "2025, The tesseract-jax Team @ Pasteur Labs"
author = "The tesseract-jax Team @ Pasteur Labs"
project = "Tesseract-JAX"
copyright = "2025, Pasteur Labs"
author = "The Tesseract-JAX Team @ Pasteur Labs + OSS contributors"

# The short X.Y version
parsed_version = re.match(r"(\d+\.\d+\.\d+)", __version__)
Expand All @@ -28,12 +30,16 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"myst_parser",
"myst_nb", # This is myst-parser + jupyter notebook support
"sphinx.ext.intersphinx",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx_autodoc_typehints",
# Copy button for code blocks
"sphinx_copybutton",
# OpenGraph metadata for social media sharing
"sphinxext.opengraph",
]

myst_enable_extensions = [
Expand Down Expand Up @@ -62,6 +68,18 @@
html_theme_options = {
"light_logo": "logo-light.png",
"dark_logo": "logo-dark.png",
"sidebar_hide_name": False,
"sidebar_hide_name": True,
}
html_css_files = ["custom.css"]


# -- Handle Jupyter notebooks ------------------------------------------------

# Do not execute notebooks during build (just take existing output)
nb_execution_mode = "off"

# Copy example notebooks to demo_notebooks folder on every build
for example_notebook in Path("../examples").glob("*/demo.ipynb"):
# Copy the example notebook to the docs folder
dest = (Path("demo_notebooks") / example_notebook.parent.name).with_suffix(".ipynb")
shutil.copyfile(example_notebook, dest)
2 changes: 1 addition & 1 deletion docs/content/api.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# API documentation
# API reference

```{eval-rst}
.. automodule:: tesseract_jax
Expand Down
79 changes: 79 additions & 0 deletions docs/content/get-started.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Get started

`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.

## Quick start

```{note}
Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+).
```

```{seealso}
For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html).
```

1. Install Tesseract-JAX:

```bash
$ pip install tesseract-jax
```

2. Build an example Tesseract:

```bash
$ tesseract build examples/simple/vectoradd_jax
```

3. Use it as part of a JAX program:

```python
import jax
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")
t.serve()

# Run it with JAX
x = jnp.ones((1000,))
y = jnp.ones((1000,))

def vector_sum(x, y):
res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
return res["vector_add"]["result"].sum()

vector_sum(x, y) # success!

# You can also use it with JAX transformations like JIT and grad
vector_sum_jit = jax.jit(vector_sum)
vector_sum_jit(x, y)

vector_sum_grad = jax.grad(vector_sum)
vector_sum_grad(x, y)
```

```{tip}
Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for ways to use Tesseract-JAX.
```

## Sharp edges

- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.

```python
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

tess = Tesseract.from_image("vectoradd")
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
```
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.

```{tip}
When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
```
1 change: 1 addition & 0 deletions docs/demo_notebooks/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.ipynb
94 changes: 20 additions & 74 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,93 +1,39 @@
# tesseract-jax
# Tesseract-JAX

`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.

## Quick start

```{note}
Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+).
```

```{seealso}
For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html).
```{include} content/get-started.md
:start-line: 2
```

1. Install Tesseract-JAX:

```bash
$ pip install tesseract-jax
```

2. Build an example Tesseract:

```bash
$ tesseract build demo/simple/vectoradd_jax
```

3. Use it as part of a JAX program:

```python
import jax
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")

# Run it with JAX
x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
## License

def vector_add(x, y):
return apply_tesseract(t, x, y)
Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).

vector_add(x, y) # success!
Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.

# You can also use it with JAX transformations like JIT and grad
vector_add_jit = jax.jit(vector_add)
vector_add_jit(x, y)

vector_add_grad = jax.grad(vector_add)
vector_add_grad(x, y)
```
```{toctree}
:caption: Usage
:maxdepth: 2
:hidden:

```{tip}
Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX.
content/get-started
content/api
```

## Sharp edges

- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.

```python
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

tess = Tesseract.from_image("vectoradd")
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
```
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
```{toctree}
:caption: Examples
:maxdepth: 2
:hidden:

```{tip}
When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
demo_notebooks/simple.ipynb
demo_notebooks/cfd.ipynb
```

```{toctree}
:caption: Contents
:caption: See also
:maxdepth: 2
:hidden:

content/api
Tesseract Core documentation <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/>
Tesseract Core docs <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/>
Tesseract User Forums <https://si-tesseract.discourse.group/>
```

## License

Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).

Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.
Binary file modified docs/static/logo-dark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/static/logo-light.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Tesseract-JAX examples

This directory contains example Tesseract configurations, notebooks. and scripts demonstrating how to use Tesseract-JAX in various contexts. Each example is self-contained and can be run independently.

## Examples

- [Simple](simple/demo.ipynb): A basic example of using Tesseract-JAX with a simple vector addition task. It demonstrates how to build a Tesseract and execute it within JAX.
- [CFD](cfd/demo.ipynb): A more complex example demonstrating how to use Tesseract-JAX to differentiate through a computational fluid dynamics (CFD) simulation Tesseract.
Loading