You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+19-8Lines changed: 19 additions & 8 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -2,7 +2,7 @@
2
2
3
3
### Tesseract-JAX
4
4
5
-
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core)as part of [JAX](https://github.com/jax-ml/jax)programs, with full support for function transformations like JIT, `grad`, and more.
5
+
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core)that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax)primitives, and makes them jittable, differentiable, and composable.
6
6
7
7
[Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) |
8
8
[Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) |
@@ -12,7 +12,16 @@
12
12
13
13
---
14
14
15
-
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
15
+
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
16
+
17
+
```python
18
+
@jax.jit
19
+
defvector_sum(x, y):
20
+
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
21
+
return res["vector_add"]["result"].sum()
22
+
23
+
jax.grad(vector_sum)(x, y) # 🎉
24
+
```
16
25
17
26
## Quick start
18
27
@@ -31,7 +40,8 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
3. Use it as part of a JAX program via the JAX-native `apply_tesseract` function:
@@ -62,22 +72,23 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
62
72
63
73
vector_sum_grad = jax.grad(vector_sum)
64
74
vector_sum_grad(x, y)
65
-
```
75
+
```
66
76
67
77
> [!TIP]
68
78
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.
69
79
70
80
## Sharp edges
71
81
72
-
-**Arrays vs. array-like objects**: Tesseract-JAXist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAXor NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
82
+
-**Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
73
83
74
84
```python
75
85
from tesseract_core import Tesseract
76
86
from tesseract_jax import apply_tesseract
77
87
78
-
tess = Tesseract.from_image("vectoradd")
79
-
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
80
-
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
88
+
tess = Tesseract.from_image("vectoradd_jax")
89
+
with Tesseract.from_image("vectoradd_jax") as tess:
apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works
81
92
```
82
93
-**Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
Copy file name to clipboardExpand all lines: docs/content/get-started.md
+7-9Lines changed: 7 additions & 9 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -1,9 +1,5 @@
1
1
# Get started
2
2
3
-
`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.
4
-
5
-
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
6
-
7
3
## Quick start
8
4
9
5
```{note}
@@ -23,7 +19,8 @@ For more detailed installation instructions, please refer to the [Tesseract Core
@@ -62,15 +59,16 @@ Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tess
62
59
63
60
## Sharp edges
64
61
65
-
-**Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
62
+
-**Arrays vs. array-like objects**: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values.
66
63
67
64
```python
68
65
from tesseract_core import Tesseract
69
66
from tesseract_jax import apply_tesseract
70
67
71
-
tess = Tesseract.from_image("vectoradd")
72
-
apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error
73
-
apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works
68
+
tess = Tesseract.from_image("vectoradd_jax")
69
+
with Tesseract.from_image("vectoradd_jax") as tess:
apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}}) # ✅ works
74
72
```
75
73
-**Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined.
Copy file name to clipboardExpand all lines: docs/index.md
+13-2Lines changed: 13 additions & 2 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -1,9 +1,20 @@
1
1
# Tesseract-JAX
2
2
3
-
```{include} content/get-started.md
4
-
:start-line: 2
3
+
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.
4
+
5
+
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
6
+
7
+
```python
8
+
@jax.jit
9
+
defvector_sum(x, y):
10
+
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
11
+
return res["vector_add"]["result"].sum()
12
+
13
+
jax.grad(vector_sum)(x, y) # 🎉
5
14
```
6
15
16
+
Want to learn more? See how to [get started](content/get-started.md) with Tesseract-JAX, explore the [API reference](content/api.md), or learn by [example](demo_notebooks/simple.ipynb).
17
+
7
18
## License
8
19
9
20
Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).
0 commit comments