Skip to content

Commit dfe2e23

Browse files
authored
Merge branch 'main' into dion/publish
2 parents 12e90a4 + 8185cd1 commit dfe2e23

29 files changed

+2282
-1546
lines changed

.github/workflows/test_examples.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
# test with oldest supported Python version only (for slow tests)
1717
python-version: ["3.10"]
1818

19-
demo:
19+
example:
2020
- simple
2121
- cfd
2222

@@ -46,7 +46,7 @@ jobs:
4646
uv sync --extra dev --frozen
4747
4848
- name: Run example
49-
working-directory: demo/${{matrix.demo}}
49+
working-directory: examples/${{matrix.example}}
5050
run: |
5151
uv pip install jupyter
5252
uv run --no-sync jupyter nbconvert --to notebook --execute demo.ipynb

README.md

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,26 @@
22

33
### Tesseract-JAX
44

5-
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, and more.
5+
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.
66

77
[Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) |
8-
[Explore the demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) |
8+
[Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) |
99
[Report an issue](https://github.com/pasteurlabs/tesseract-jax/issues) |
1010
[Talk to the community](https://si-tesseract.discourse.group/) |
1111
[Contribute](CONTRIBUTING.md)
1212

1313
---
1414

15-
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
15+
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
16+
17+
```python
18+
@jax.jit
19+
def vector_sum(x, y):
20+
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
21+
return res["vector_add"]["result"].sum()
22+
23+
jax.grad(vector_sum)(x, y) # 🎉
24+
```
1625

1726
## Quick start
1827

@@ -31,7 +40,8 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
3140
2. Build an example Tesseract:
3241

3342
```bash
34-
$ tesseract build demo/simple/vectoradd_jax
43+
$ git clone https://github.com/pasteurlabs/tesseract-jax
44+
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
3545
```
3646

3747
3. Use it as part of a JAX program via the JAX-native `apply_tesseract` function:
@@ -44,38 +54,41 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
4454

4555
# Load the Tesseract
4656
t = Tesseract.from_image("vectoradd_jax")
57+
t.serve()
4758

4859
# Run it with JAX
49-
x = jnp.ones((1000, 1000))
50-
y = jnp.ones((1000, 1000))
60+
x = jnp.ones((1000,))
61+
y = jnp.ones((1000,))
5162

52-
def vector_add(x, y):
53-
return apply_tesseract(t, x, y)
63+
def vector_sum(x, y):
64+
res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
65+
return res["vector_add"]["result"].sum()
5466

55-
vector_add(x, y) # success!
67+
vector_sum(x, y) # success!
5668

57-
# You can also use it with JAX transformations like JIT and grad
58-
vector_add_jit = jax.jit(vector_add)
59-
vector_add_jit(x, y)
69+
# You can also use it with JAX transformations like JIT and grad
70+
vector_sum_jit = jax.jit(vector_sum)
71+
vector_sum_jit(x, y)
6072

61-
vector_add_grad = jax.grad(vector_add)
62-
vector_add_grad(x, y)
63-
```
73+
vector_sum_grad = jax.grad(vector_sum)
74+
vector_sum_grad(x, y)
75+
```
6476

6577
> [!TIP]
66-
> Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX.
78+
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.
6779
6880
## Sharp edges
6981

70-
- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
82+
- **Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
7183

7284
```python
7385
from tesseract_core import Tesseract
7486
from tesseract_jax import apply_tesseract
7587

76-
tess = Tesseract.from_image("vectoradd")
77-
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
78-
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
88+
tess = Tesseract.from_image("vectoradd_jax")
89+
with Tesseract.from_image("vectoradd_jax") as tess:
90+
apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}}) # ❌ raises an error
91+
apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works
7992
```
8093
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
8194

@@ -84,6 +97,6 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
8497
8598
## License
8699

87-
Tesseract JAX is licensed under the [Apache License 2.0](LICENSE) and is free to use, modify, and distribute (under the terms of the license).
100+
Tesseract-JAX is licensed under the [Apache License 2.0](LICENSE) and is free to use, modify, and distribute (under the terms of the license).
88101

89102
Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.

demo/cfd/Readme.md

Lines changed: 0 additions & 5 deletions
This file was deleted.

demo/cfd/cfd-tesseract/tesseract_api.py

Lines changed: 0 additions & 185 deletions
This file was deleted.

0 commit comments

Comments
 (0)