Skip to content

Commit 0178442

Browse files
committed
ensure gradients of example tesseract are well-behaved
1 parent ea59d54 commit 0178442

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
6262

6363
vector_sum_grad = jax.grad(vector_sum)
6464
vector_sum_grad(x, y)
65-
```
65+
```
6666

6767
> [!TIP]
6868
> Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX.

examples/simple/vectoradd_jax/tesseract_api.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,20 @@ def apply_jit(inputs: dict) -> dict:
6565
b_scaled = inputs["b"]["s"] * inputs["b"]["v"]
6666
add_result = a_scaled + b_scaled
6767
min_result = a_scaled - b_scaled
68+
69+
def safe_norm(x, ord):
70+
# Compute the norm of a vector, adding a small epsilon to ensure
71+
# differentiability and avoid division by zero
72+
return jnp.power(jnp.power(x, ord).sum() + 1e-8, 1 / ord)
73+
6874
return {
6975
"vector_add": {
7076
"result": add_result,
71-
"normed_result": add_result
72-
/ jnp.linalg.norm(add_result, ord=inputs["norm_ord"]),
77+
"normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]),
7378
},
7479
"vector_min": {
7580
"result": min_result,
76-
"normed_result": min_result
77-
/ jnp.linalg.norm(min_result, ord=inputs["norm_ord"]),
81+
"normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]),
7882
},
7983
}
8084

0 commit comments

Comments
 (0)