Skip to content

Commit 7216e26

Browse files
committed
working
1 parent 3878c39 commit 7216e26

File tree

3 files changed

+25
-23
lines changed

3 files changed

+25
-23
lines changed

examples/simple/partial/tesseract_api.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def apply(inputs: InputSchema) -> OutputSchema:
3131
}
3232

3333

34-
# def vector_jacobian_product(
35-
# inputs: InputSchema,
36-
# vjp_inputs: set[str],
37-
# vjp_outputs: set[str],
38-
# cotangent_vector: dict[str, Any],
39-
# ):
34+
def vector_jacobian_product(
35+
inputs: InputSchema,
36+
vjp_inputs: set[str],
37+
vjp_outputs: set[str],
38+
cotangent_vector: dict[str, Any],
39+
):
4040

41-
# return {
42-
# "a": 2.0 * cotangent_vector["b"],
43-
# }
41+
return {
42+
"a": 2.0 * cotangent_vector["b"],
43+
}
4444

4545

4646

tesseract_jax/tesseract_compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def apply(
159159

160160
out_data = self.client.apply(inputs)
161161

162-
out_data, output_pytreedef = jax.tree.flatten(out_data)
163-
return tuple(out_data), output_pytreedef
162+
out_data = tuple(jax.tree.flatten(out_data)[0])
163+
return out_data
164164

165165

166166
def apply_pytree(

test.ipynb

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 4,
5+
"execution_count": 1,
66
"id": "4317d6a5",
77
"metadata": {},
88
"outputs": [
99
{
10-
"ename": "XlaRuntimeError",
11-
"evalue": "INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel_launcher.py\", line 18, in <module>\n File \"/home/ar/.local/lib/python3.12/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelapp.py\", line 739, in start\n File \"/home/ar/.local/lib/python3.12/site-packages/tornado/platform/asyncio.py\", line 211, in start\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/asyncio/base_events.py\", line 645, in run_forever\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/asyncio/base_events.py\", line 1999, in _run_once\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/asyncio/events.py\", line 88, in _run\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 519, in dispatch_queue\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 508, in process_one\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 400, in dispatch_shell\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 368, in execute_request\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 455, in do_execute\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/zmqshell.py\", line 577, in run_cell\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3116, in run_cell\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3171, in _run_cell\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/async_helpers.py\", line 128, in _pseudo_sync_runner\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3394, in run_cell_async\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3639, in run_ast_nodes\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3699, in run_code\n File \"/tmp/ipykernel_10673/330698556.py\", line 13, in <module>\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/traceback_util.py\", line 180, in reraise_with_filtered_traceback\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/pjit.py\", line 270, in cache_miss\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/pjit.py\", line 149, in _python_pjit_helper\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/pjit.py\", line 1804, in _pjit_call_impl_python\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/profiler.py\", line 364, in wrapper\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py\", line 1353, in __call__\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/callback.py\", line 793, in _wrapped_callback\nRuntimeError: Incorrect output shape for return value #0: Expected: (3,), Actual: (2, 3)",
12-
"output_type": "error",
13-
"traceback": [
14-
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
15-
"\u001b[31mXlaRuntimeError\u001b[39m Traceback (most recent call last)",
16-
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 13\u001b[39m\n\u001b[32m 8\u001b[39m vectoradd = Tesseract.from_tesseract_api(\u001b[33m\"\u001b[39m\u001b[33mexamples/simple/partial/tesseract_api.py\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 11\u001b[39m input_dict = {\u001b[33m\"\u001b[39m\u001b[33ma\u001b[39m\u001b[33m\"\u001b[39m: jnp.array([\u001b[32m1.0\u001b[39m, \u001b[32m2.0\u001b[39m, \u001b[32m3.0\u001b[39m], dtype=\u001b[33m\"\u001b[39m\u001b[33mfloat32\u001b[39m\u001b[33m\"\u001b[39m)}\n\u001b[32m---> \u001b[39m\u001b[32m13\u001b[39m outputs = \u001b[43mjax\u001b[49m\u001b[43m.\u001b[49m\u001b[43mjit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mapply_tesseract\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvectoradd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minput_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 14\u001b[39m \u001b[38;5;66;03m# outputs = apply_tesseract(vectoradd, inputs=input_dict)\u001b[39;00m\n\u001b[32m 15\u001b[39m pprint(outputs)\n",
17-
" \u001b[31m[... skipping hidden 5 frame]\u001b[39m\n",
18-
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1353\u001b[39m, in \u001b[36mExecuteReplicated.__call__\u001b[39m\u001b[34m(self, *args)\u001b[39m\n\u001b[32m 1350\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mself\u001b[39m.ordered_effects \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.has_unordered_effects\n\u001b[32m 1351\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.has_host_callbacks):\n\u001b[32m 1352\u001b[39m input_bufs = \u001b[38;5;28mself\u001b[39m._add_tokens_to_inputs(input_bufs)\n\u001b[32m-> \u001b[39m\u001b[32m1353\u001b[39m results = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mxla_executable\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexecute_sharded\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1354\u001b[39m \u001b[43m \u001b[49m\u001b[43minput_bufs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwith_tokens\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[32m 1355\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1357\u001b[39m result_token_bufs = results.disassemble_prefix_into_single_device_arrays(\n\u001b[32m 1358\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m.ordered_effects))\n\u001b[32m 1359\u001b[39m sharded_runtime_token = results.consume_token()\n",
19-
"\u001b[31mXlaRuntimeError\u001b[39m: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel_launcher.py\", line 18, in <module>\n File \"/home/ar/.local/lib/python3.12/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelapp.py\", line 739, in start\n File \"/home/ar/.local/lib/python3.12/site-packages/tornado/platform/asyncio.py\", line 211, in start\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/asyncio/base_events.py\", line 645, in run_forever\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/asyncio/base_events.py\", line 1999, in _run_once\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/asyncio/events.py\", line 88, in _run\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 519, in dispatch_queue\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 508, in process_one\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 400, in dispatch_shell\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 368, in execute_request\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 455, in do_execute\n File \"/home/ar/.local/lib/python3.12/site-packages/ipykernel/zmqshell.py\", line 577, in run_cell\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3116, in run_cell\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3171, in _run_cell\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/async_helpers.py\", line 128, in _pseudo_sync_runner\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3394, in run_cell_async\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3639, in run_ast_nodes\n File \"/home/ar/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3699, in run_code\n File \"/tmp/ipykernel_10673/330698556.py\", line 13, in <module>\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/traceback_util.py\", line 180, in reraise_with_filtered_traceback\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/pjit.py\", line 270, in cache_miss\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/pjit.py\", line 149, in _python_pjit_helper\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/pjit.py\", line 1804, in _pjit_call_impl_python\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/profiler.py\", line 364, in wrapper\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py\", line 1353, in __call__\n File \"/home/ar/miniconda3/envs/tessj/lib/python3.12/site-packages/jax/_src/callback.py\", line 793, in _wrapped_callback\nRuntimeError: Incorrect output shape for return value #0: Expected: (3,), Actual: (2, 3)"
10+
"name": "stderr",
11+
"output_type": "stream",
12+
"text": [
13+
"WARNING:2025-09-18 09:34:00,243:jax._src.xla_bridge:864: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
14+
]
15+
},
16+
{
17+
"name": "stdout",
18+
"output_type": "stream",
19+
"text": [
20+
"{'b': Array([2., 4., 6.], dtype=float32),\n",
21+
" 'c': Array([1., 1., 1.], dtype=float32)}\n"
2022
]
2123
}
2224
],
@@ -40,7 +42,7 @@
4042
},
4143
{
4244
"cell_type": "code",
43-
"execution_count": 13,
45+
"execution_count": 2,
4446
"id": "1adb5855",
4547
"metadata": {},
4648
"outputs": [

0 commit comments

Comments
 (0)