Skip to content

Commit 7245274

Browse files
committed
reproduce error
1 parent cbc773e commit 7245274

File tree

5 files changed

+98
-0
lines changed

5 files changed

+98
-0
lines changed

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def get_tesseract_folders():
2020
tesseract_folders = [
2121
"univariate_tesseract",
2222
"nested_tesseract",
23+
"non_abstract_tesseract",
24+
"vectoradd_tesseract",
2325
# Add more as needed
2426
]
2527
return tesseract_folders
@@ -86,3 +88,4 @@ def served_tesseract():
8688
served_univariate_tesseract_raw = make_tesseract_fixture("univariate_tesseract")
8789
served_nested_tesseract_raw = make_tesseract_fixture("nested_tesseract")
8890
served_non_abstract_tesseract = make_tesseract_fixture("non_abstract_tesseract")
91+
served_vectoradd_tesseract = make_tesseract_fixture("vectoradd_tesseract")

tests/test_endtoend.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
import jax
6+
import jax.numpy as jnp
67
import numpy as np
78
import pytest
89
from jax.typing import ArrayLike
@@ -562,6 +563,46 @@ def f(a):
562563
_assert_pytree_isequal(result, result_ref)
563564

564565

566+
@pytest.mark.parametrize("use_jit", [True, False])
567+
def test_tesseract_loss(served_vectoradd_tesseract, use_jit):
568+
vectoradd_tess = Tesseract(served_vectoradd_tesseract)
569+
a = np.array([1.0, 2.0, 3.0], dtype="float32")
570+
571+
# b = jax.lax.stop_gradient(b)
572+
573+
def loss_fn(a):
574+
b = np.array([4.0, 5.0, 6.0], dtype="float32")
575+
576+
vectoradd_fn_a: jax.Callable = lambda a: apply_tesseract(
577+
vectoradd_tess,
578+
inputs=dict(
579+
a=a,
580+
b=b,
581+
),
582+
)
583+
584+
c = vectoradd_fn_a(a)["c"]
585+
586+
vectoradd_fn_b: jax.Callable = lambda a: apply_tesseract(
587+
vectoradd_tess,
588+
inputs=dict(
589+
a=a,
590+
b=c,
591+
),
592+
)
593+
594+
outputs = vectoradd_fn_b(a)
595+
596+
return jnp.sum((outputs["c"]) ** 2)
597+
598+
if use_jit:
599+
loss_fn = jax.jit(loss_fn)
600+
601+
value_and_grad_fn = jax.value_and_grad(loss_fn)
602+
603+
assert value_and_grad_fn(a) is not None
604+
605+
565606
def test_non_abstract_tesseract_vjp(served_non_abstract_tesseract):
566607
non_abstract_tess = Tesseract(served_non_abstract_tesseract)
567608

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Pasteur Labs. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from typing import Any
6+
7+
from pydantic import BaseModel, Field
8+
from tesseract_core.runtime import Array, Differentiable, Float32
9+
10+
11+
class InputSchema(BaseModel):
12+
a: Differentiable[Array[(None,), Float32]] = Field(description="Arbitrary vector a")
13+
b: Array[(None,), Float32] = Field(description="Arbitrary vector b")
14+
15+
16+
class OutputSchema(BaseModel):
17+
c: Differentiable[Array[(None,), Float32]] = Field(
18+
description="Vector s_a·a + s_b·b"
19+
)
20+
21+
22+
def apply(inputs: InputSchema) -> OutputSchema:
23+
"""Adds two vectors `a` and `b`."""
24+
return OutputSchema(
25+
c=inputs.a + inputs.b,
26+
)
27+
28+
29+
def abstract_eval(abstract_inputs):
30+
"""Abstract evaluation of the addition operation."""
31+
return {
32+
"c": abstract_inputs.a,
33+
}
34+
35+
36+
def vector_jacobian_product(
37+
inputs: InputSchema,
38+
vjp_inputs: set[str],
39+
vjp_outputs: set[str],
40+
cotangent_vector: dict[str, Any],
41+
):
42+
return {
43+
"a": cotangent_vector["c"],
44+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
name: vectoradd_tesseracts
2+
version: "2025-11-05"
3+
description: |
4+
Tesseract that adds two vectors. Uses jax internally.
5+
6+
build_config:
7+
target_platform: "native"
8+
# package_data: []
9+
# custom_build_steps: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
jax[cpu]

0 commit comments

Comments
 (0)