Skip to content

Commit 6018360

Browse files
committed
flatten unflatten, why
1 parent 227bead commit 6018360

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

examples/simple/partial/tesseract_api.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,20 @@ def apply(inputs: InputSchema) -> OutputSchema:
4444

4545

4646

47+
# def abstract_eval(abstract_inputs):
48+
# """Calculate output shape of apply from the shape of its inputs."""
49+
# return {
50+
# "b": ShapeDType(shape=(None,), dtype="float32"),
51+
# "c": ShapeDType(shape=(None,), dtype="float32"),
52+
# }
53+
54+
55+
56+
4757
def abstract_eval(abstract_inputs):
4858
"""Calculate output shape of apply from the shape of its inputs."""
4959
return {
50-
"b": ShapeDType(shape=(None,), dtype="float32"),
51-
"c": ShapeDType(shape=(None,), dtype="float32"),
60+
"b": ShapeDType(shape=(abstract_inputs.a.shape[0],), dtype="float32"),
61+
"c": ShapeDType(shape=(3,), dtype="float32"),
5262
}
5363

tesseract_jax/tesseract_compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def apply(
159159

160160
out_data = self.client.apply(inputs)
161161

162-
out_data = tuple(jax.tree.flatten(out_data)[0])
163-
return out_data
162+
out_data, output_pytreedef = jax.tree.flatten(out_data)
163+
return tuple(out_data), output_pytreedef
164164

165165
def jacobian_vector_product(
166166
self,

test.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 6,
5+
"execution_count": 3,
66
"id": "4317d6a5",
77
"metadata": {},
88
"outputs": [
99
{
1010
"name": "stdout",
1111
"output_type": "stream",
1212
"text": [
13-
"(array([2., 4., 6.], dtype=float32), array([1., 1., 1.], dtype=float32))\n"
13+
"[Array([2., 4., 6.], dtype=float32), Array([1., 1., 1.], dtype=float32)]\n"
1414
]
1515
}
1616
],
@@ -20,14 +20,15 @@
2020
"from tesseract_core import Tesseract\n",
2121
"from tesseract_jax import apply_tesseract\n",
2222
"import jax.numpy as jnp\n",
23+
"import jax\n",
2324
"\n",
2425
"vectoradd = Tesseract.from_tesseract_api(\"examples/simple/partial/tesseract_api.py\")\n",
2526
"\n",
2627
"\n",
2728
"input_dict = {\"a\": jnp.array([1.0, 2.0, 3.0], dtype=\"float32\")}\n",
2829
"\n",
2930
"\n",
30-
"outputs = apply_tesseract(vectoradd, inputs=input_dict)\n",
31+
"outputs = jax.jit(apply_tesseract)(vectoradd, inputs=input_dict)\n",
3132
"pprint(outputs)"
3233
]
3334
},

0 commit comments

Comments
 (0)