diff --git a/README.md b/README.md index ae5aa81..38a3049 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ jax.grad(vector_sum)(x, y) # 🎉 vector_sum_grad = jax.grad(vector_sum) vector_sum_grad(x, y) - ``` + ``` > [!TIP] > Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX. diff --git a/examples/simple/vectoradd_jax/tesseract_api.py b/examples/simple/vectoradd_jax/tesseract_api.py index 5451902..0552264 100644 --- a/examples/simple/vectoradd_jax/tesseract_api.py +++ b/examples/simple/vectoradd_jax/tesseract_api.py @@ -65,16 +65,20 @@ def apply_jit(inputs: dict) -> dict: b_scaled = inputs["b"]["s"] * inputs["b"]["v"] add_result = a_scaled + b_scaled min_result = a_scaled - b_scaled + + def safe_norm(x, ord): + # Compute the norm of a vector, adding a small epsilon to ensure + # differentiability and avoid division by zero + return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord) + return { "vector_add": { "result": add_result, - "normed_result": add_result - / jnp.linalg.norm(add_result, ord=inputs["norm_ord"]), + "normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]), }, "vector_min": { "result": min_result, - "normed_result": min_result - / jnp.linalg.norm(min_result, ord=inputs["norm_ord"]), + "normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]), }, }