Skip to content

Commit fe917bb

Browse files
committed
doc: add rendered demos to docs
1 parent 466fb40 commit fe917bb

File tree

4 files changed

+27
-82
lines changed

4 files changed

+27
-82
lines changed

docs/conf.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
88

99
import re
10+
import shutil
11+
from pathlib import Path
1012

1113
from tesseract_jax import __version__
1214

@@ -28,7 +30,7 @@
2830
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
2931

3032
extensions = [
31-
"myst_parser",
33+
"myst_nb", # This is myst-parser + jupyter notebook support
3234
"sphinx.ext.intersphinx",
3335
"sphinx.ext.autodoc",
3436
"sphinx.ext.napoleon",
@@ -65,3 +67,15 @@
6567
"sidebar_hide_name": False,
6668
}
6769
html_css_files = ["custom.css"]
70+
71+
72+
# -- Handle Jupyter notebooks ------------------------------------------------
73+
74+
# Do not execute notebooks during build (just take existing output)
75+
nb_execution_mode = "off"
76+
77+
# Copy example notebooks to demo_notebooks folder on every build
78+
for example_notebook in Path("../demo").glob("*/demo.ipynb"):
79+
# Copy the example notebook to the docs folder
80+
dest = (Path("demo_notebooks") / example_notebook.parent.name).with_suffix(".ipynb")
81+
shutil.copyfile(example_notebook, dest)

docs/content/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# API documentation
1+
# API reference
22

33
```{eval-rst}
44
.. automodule:: tesseract_jax

docs/index.md

Lines changed: 10 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,23 @@
1-
# tesseract-jax
2-
3-
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.
4-
5-
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
6-
7-
## Quick start
8-
9-
```{note}
10-
Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+).
11-
```
12-
13-
```{seealso}
14-
For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html).
15-
```
16-
17-
1. Install Tesseract-JAX:
18-
19-
```bash
20-
$ pip install tesseract-jax
21-
```
22-
23-
2. Build an example Tesseract:
24-
25-
```bash
26-
$ tesseract build demo/simple/vectoradd_jax
27-
```
28-
29-
3. Use it as part of a JAX program:
30-
31-
```python
32-
import jax
33-
import jax.numpy as jnp
34-
from tesseract_core import Tesseract
35-
from tesseract_jax import apply_tesseract
36-
37-
# Load the Tesseract
38-
t = Tesseract.from_image("vectoradd_jax")
39-
40-
# Run it with JAX
41-
x = jnp.ones((1000, 1000))
42-
y = jnp.ones((1000, 1000))
43-
44-
def vector_add(x, y):
45-
return apply_tesseract(t, x, y)
46-
47-
vector_add(x, y) # success!
48-
49-
# You can also use it with JAX transformations like JIT and grad
50-
vector_add_jit = jax.jit(vector_add)
51-
vector_add_jit(x, y)
52-
53-
vector_add_grad = jax.grad(vector_add)
54-
vector_add_grad(x, y)
55-
```
56-
57-
```{tip}
58-
Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX.
1+
```{include} content/index.md
592
```
603

61-
## Sharp edges
62-
63-
- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
64-
65-
```python
66-
from tesseract_core import Tesseract
67-
from tesseract_jax import apply_tesseract
68-
69-
tess = Tesseract.from_image("vectoradd")
70-
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
71-
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
72-
```
73-
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
74-
75-
```{tip}
76-
When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
77-
```
784

795
```{toctree}
806
:caption: Contents
817
:maxdepth: 2
828
:hidden:
839
10+
content/index
8411
content/api
85-
Tesseract Core documentation <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/>
12+
Tesseract Core docs <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/>
8613
Tesseract User Forums <https://si-tesseract.discourse.group/>
8714
```
8815

89-
## License
90-
91-
Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).
16+
```{toctree}
17+
:caption: Demos
18+
:maxdepth: 2
19+
:hidden:
9220
93-
Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.
21+
demo_notebooks/simple.ipynb
22+
demo_notebooks/cfd.ipynb
23+
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ docs = [
2222
"sphinx_autodoc_typehints",
2323
"furo",
2424
"myst-parser",
25+
"myst-nb",
2526
]
2627
# TODO: add dev dependencies here *and* in requirements-dev.txt
2728
dev = [

0 commit comments

Comments
 (0)