|
1 | | -# tesseract-jax |
2 | | - |
3 | | -`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more. |
4 | | - |
5 | | -The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines. |
6 | | - |
7 | | -## Quick start |
8 | | - |
9 | | -```{note} |
10 | | -Before proceeding, make sure you have a [working installation of Docker](https://docs.docker.com/engine/install/) and a modern Python installation (Python 3.10+). |
11 | | -``` |
12 | | - |
13 | | -```{seealso} |
14 | | -For more detailed installation instructions, please refer to the [Tesseract Core documentation](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/introduction/installation.html). |
15 | | -``` |
16 | | - |
17 | | -1. Install Tesseract-JAX: |
18 | | - |
19 | | - ```bash |
20 | | - $ pip install tesseract-jax |
21 | | - ``` |
22 | | - |
23 | | -2. Build an example Tesseract: |
24 | | - |
25 | | - ```bash |
26 | | - $ tesseract build demo/simple/vectoradd_jax |
27 | | - ``` |
28 | | - |
29 | | -3. Use it as part of a JAX program: |
30 | | - |
31 | | - ```python |
32 | | - import jax |
33 | | - import jax.numpy as jnp |
34 | | - from tesseract_core import Tesseract |
35 | | - from tesseract_jax import apply_tesseract |
36 | | - |
37 | | - # Load the Tesseract |
38 | | - t = Tesseract.from_image("vectoradd_jax") |
39 | | - |
40 | | - # Run it with JAX |
41 | | - x = jnp.ones((1000, 1000)) |
42 | | - y = jnp.ones((1000, 1000)) |
43 | | - |
44 | | - def vector_add(x, y): |
45 | | - return apply_tesseract(t, x, y) |
46 | | - |
47 | | - vector_add(x, y) # success! |
48 | | - |
49 | | - # You can also use it with JAX transformations like JIT and grad |
50 | | - vector_add_jit = jax.jit(vector_add) |
51 | | - vector_add_jit(x, y) |
52 | | - |
53 | | - vector_add_grad = jax.grad(vector_add) |
54 | | - vector_add_grad(x, y) |
55 | | - ``` |
56 | | - |
57 | | -```{tip} |
58 | | -Now you're ready to jump into our [demos](https://github.com/pasteurlabs/tesseract-jax/tree/main/demo) for more examples on how to use Tesseract-JAX. |
| 1 | +```{include} content/index.md |
59 | 2 | ``` |
60 | 3 |
|
61 | | -## Sharp edges |
62 | | - |
63 | | -- **Arrays vs. array-like objects**: Tesseract-JAX ist stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not just any array-like (such as Python floats or lists). As a result, you may need to convert your inputs to JAX arrays before passing them to Tesseract-JAX, including scalar values. |
64 | | - |
65 | | - ```python |
66 | | - from tesseract_core import Tesseract |
67 | | - from tesseract_jax import apply_tesseract |
68 | | - |
69 | | - tess = Tesseract.from_image("vectoradd") |
70 | | - apply_tesseract(tess, {"a": 1.0, "b": 2.0}) # ❌ raises an error |
71 | | - apply_tesseract(tess, {"a": jnp.array(1.0), "b": jnp.array(2.0)}) # ✅ works |
72 | | - ``` |
73 | | -- **Additional required endpoints**: Tesseract-JAX requires the [`abstract_eval`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#abstract-eval) Tesseract endpoint to be defined for all operations. This is because JAX mandates abstract evaluation of all operations before they are executed. Additionally, many gradient transformations like `jax.grad` require [`vector_jacobian_product`](https://docs.pasteurlabs.ai/projects/tesseract-core/latest/content/api/endpoints.html#vector-jacobian-product) to be defined. |
74 | | - |
75 | | -```{tip} |
76 | | -When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`. |
77 | | -``` |
78 | 4 |
|
79 | 5 | ```{toctree} |
80 | 6 | :caption: Contents |
81 | 7 | :maxdepth: 2 |
82 | 8 | :hidden: |
83 | 9 |
|
| 10 | +content/index |
84 | 11 | content/api |
85 | | -Tesseract Core documentation <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/> |
| 12 | +Tesseract Core docs <https://docs.pasteurlabs.ai/projects/tesseract-core/latest/> |
86 | 13 | Tesseract User Forums <https://si-tesseract.discourse.group/> |
87 | 14 | ``` |
88 | 15 |
|
89 | | -## License |
90 | | - |
91 | | -Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license). |
| 16 | +```{toctree} |
| 17 | +:caption: Demos |
| 18 | +:maxdepth: 2 |
| 19 | +:hidden: |
92 | 20 |
|
93 | | -Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission. |
| 21 | +demo_notebooks/simple.ipynb |
| 22 | +demo_notebooks/cfd.ipynb |
| 23 | +``` |
0 commit comments