From 01784421dc61d4738beb6964dc0acfbe49a61f0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Wed, 23 Apr 2025 21:20:22 +0200 Subject: [PATCH 1/2] ensure gradients of example tesseract are well-behaved --- README.md | 2 +- examples/simple/vectoradd_jax/tesseract_api.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b179849..2aa8f71 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser vector_sum_grad = jax.grad(vector_sum) vector_sum_grad(x, y) - ``` + ``` > [!TIP] > Now you're ready to jump into our [examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) for more ways to use Tesseract-JAX. diff --git a/examples/simple/vectoradd_jax/tesseract_api.py b/examples/simple/vectoradd_jax/tesseract_api.py index 5451902..e9a6b44 100644 --- a/examples/simple/vectoradd_jax/tesseract_api.py +++ b/examples/simple/vectoradd_jax/tesseract_api.py @@ -65,16 +65,20 @@ def apply_jit(inputs: dict) -> dict: b_scaled = inputs["b"]["s"] * inputs["b"]["v"] add_result = a_scaled + b_scaled min_result = a_scaled - b_scaled + + def safe_norm(x, ord): + # Compute the norm of a vector, adding a small epsilon to ensure + # differentiability and avoid division by zero + return jnp.power(jnp.power(x, ord).sum() + 1e-8, 1 / ord) + return { "vector_add": { "result": add_result, - "normed_result": add_result - / jnp.linalg.norm(add_result, ord=inputs["norm_ord"]), + "normed_result": add_result / safe_norm(add_result, ord=inputs["norm_ord"]), }, "vector_min": { "result": min_result, - "normed_result": min_result - / jnp.linalg.norm(min_result, ord=inputs["norm_ord"]), + "normed_result": min_result / safe_norm(min_result, ord=inputs["norm_ord"]), }, } From f5c2946fdf3d4a40a533080b8ecd92b695f67a31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Wed, 23 Apr 2025 21:46:03 +0200 Subject: [PATCH 2/2] add abs for odd powers --- examples/simple/vectoradd_jax/tesseract_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple/vectoradd_jax/tesseract_api.py b/examples/simple/vectoradd_jax/tesseract_api.py index e9a6b44..0552264 100644 --- a/examples/simple/vectoradd_jax/tesseract_api.py +++ b/examples/simple/vectoradd_jax/tesseract_api.py @@ -69,7 +69,7 @@ def apply_jit(inputs: dict) -> dict: def safe_norm(x, ord): # Compute the norm of a vector, adding a small epsilon to ensure # differentiability and avoid division by zero - return jnp.power(jnp.power(x, ord).sum() + 1e-8, 1 / ord) + return jnp.power(jnp.power(jnp.abs(x), ord).sum() + 1e-8, 1 / ord) return { "vector_add": {