Skip to content

Commit 5289a63

Browse files
committed
Merge branch 'main' into jacan/demos
2 parents f66647e + 613c84d commit 5289a63

File tree

4 files changed

+47
-23
lines changed

4 files changed

+47
-23
lines changed

README.md

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
### Tesseract-JAX
44

5-
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of [JAX](https://github.com/jax-ml/jax) programs, with full support for function transformations like JIT, `grad`, and more.
5+
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.
66

77
[Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) |
88
[Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) |
@@ -12,7 +12,16 @@
1212

1313
---
1414

15-
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
15+
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
16+
17+
```python
18+
@jax.jit
19+
def vector_sum(x, y):
20+
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
21+
return res["vector_add"]["result"].sum()
22+
23+
jax.grad(vector_sum)(x, y) # 🎉
24+
```
1625

1726
## Quick start
1827

@@ -31,7 +40,8 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
3140
2. Build an example Tesseract:
3241

3342
```bash
34-
$ tesseract build examples/simple/vectoradd_jax
43+
$ git clone https://github.com/pasteurlabs/tesseract-jax
44+
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
3545
```
3646

3747
3. Use it as part of a JAX program via the JAX-native `apply_tesseract` function:
@@ -62,22 +72,23 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
6272

6373
vector_sum_grad = jax.grad(vector_sum)
6474
vector_sum_grad(x, y)
65-
```
75+
```
6676

6777
> [!TIP]
6878
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.
6979
7080
## Sharp edges
7181

72-
- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
82+
- **Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
7383

7484
```python
7585
from tesseract_core import Tesseract
7686
from tesseract_jax import apply_tesseract
7787

78-
tess = Tesseract.from_image("vectoradd")
79-
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
80-
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
88+
tess = Tesseract.from_image("vectoradd_jax")
89+
with Tesseract.from_image("vectoradd_jax") as tess:
90+
apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}}) # ❌ raises an error
91+
apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works
8192
```
8293
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
8394

docs/content/get-started.md

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
# Get started
22

3-
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.
4-
5-
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
6-
73
## Quick start
84

95
```{note}
@@ -23,7 +19,8 @@ For more detailed installation instructions, please refer to the [Tesseract Core
2319
2. Build an example Tesseract:
2420

2521
```bash
26-
$ tesseract build examples/simple/vectoradd_jax
22+
$ git clone https://github.com/pasteurlabs/tesseract-jax
23+
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
2724
```
2825

2926
3. Use it as part of a JAX program:
@@ -62,15 +59,16 @@ Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tess
6259

6360
## Sharp edges
6461

65-
- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
62+
- **Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
6663

6764
```python
6865
from tesseract_core import Tesseract
6966
from tesseract_jax import apply_tesseract
7067

71-
tess = Tesseract.from_image("vectoradd")
72-
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
73-
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
68+
tess = Tesseract.from_image("vectoradd_jax")
69+
with Tesseract.from_image("vectoradd_jax") as tess:
70+
apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}}) # ❌ raises an error
71+
apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works
7472
```
7573
- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
7674

docs/index.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
# Tesseract-JAX
22

3-
```{include} content/get-started.md
4-
:start-line: 2
3+
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.
4+
5+
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
6+
7+
```python
8+
@jax.jit
9+
def vector_sum(x, y):
10+
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
11+
return res["vector_add"]["result"].sum()
12+
13+
jax.grad(vector_sum)(x, y) # 🎉
514
```
615

16+
Want to learn more? See how to [get started](content/get-started.md) with Tesseract-JAX, explore the [API reference](content/api.md), or learn by [example](demo_notebooks/simple.ipynb).
17+
718
## License
819

920
Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).

examples/simple/vectoradd_jax/tesseract_api.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,20 @@ def apply_jit(inputs: dict) -> dict:
6565
b_scaled = inputs["b"]["s"] * inputs["b"]["v"]
6666
add_result = a_scaled + b_scaled
6767
min_result = a_scaled - b_scaled
68+
69+
def safe_norm(x, ord):
70+
# Compute the norm of a vector, adding a small epsilon to ensure
71+
# differentiability and avoid division by zero
72+
return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)
73+
6874
return {
6975
"vector_add": {
7076
"result": add_result,
71-
"normed_result": add_result
72-
/ jnp.linalg.norm(add_result, ord=inputs["norm_ord"]),
77+
"normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]),
7378
},
7479
"vector_min": {
7580
"result": min_result,
76-
"normed_result": min_result
77-
/ jnp.linalg.norm(min_result, ord=inputs["norm_ord"]),
81+
"normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]),
7882
},
7983
}
8084

0 commit comments

Comments
 (0)