Skip to content

Commit 613c84d

Browse files
doc: ensure gradients of example tesseract are well-behaved (#14)
#### Relevant issue or PR n/a #### Description of changes Ensure example Tesseract is actually differentiable everywhere (even at zero). Without this, gradients returned `nan` when evaluated for zero-vectors. #### Testing done manual #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract-jax/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: Dion Häfner <[email protected]> --------- Co-authored-by: Andrei Paleyes <[email protected]>
1 parent d4e1208 commit 613c84d

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ jax.grad(vector_sum)(x, y) # 🎉
7272

7373
vector_sum_grad = jax.grad(vector_sum)
7474
vector_sum_grad(x, y)
75-
```
75+
```
7676

7777
> [!TIP]
7878
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.

examples/simple/vectoradd_jax/tesseract_api.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,20 @@ def apply_jit(inputs: dict) -> dict:
6565
b_scaled = inputs["b"]["s"] * inputs["b"]["v"]
6666
add_result = a_scaled + b_scaled
6767
min_result = a_scaled - b_scaled
68+
69+
def safe_norm(x, ord):
70+
# Compute the norm of a vector, adding a small epsilon to ensure
71+
# differentiability and avoid division by zero
72+
return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord)
73+
6874
return {
6975
"vector_add": {
7076
"result": add_result,
71-
"normed_result": add_result
72-
/ jnp.linalg.norm(add_result, ord=inputs["norm_ord"]),
77+
"normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]),
7378
},
7479
"vector_min": {
7580
"result": min_result,
76-
"normed_result": min_result
77-
/ jnp.linalg.norm(min_result, ord=inputs["norm_ord"]),
81+
"normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]),
7882
},
7983
}
8084

0 commit comments

Comments
 (0)