Skip to content

Commit 69510e5

Browse files
committed
Merge branch 'main' into dion/demos-to-docs
2 parents 343b05a + 67d918f commit 69510e5

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

examples/cfd/cfd-tesseract/tesseract_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def jvp_jit(
204204
filtered_apply,
205205
[flatten_with_paths(inputs, include_paths=jvp_inputs)],
206206
[tangent_vector],
207-
)
207+
)[1]
208208

209209

210210
@eqx.filter_jit

examples/simple/vectoradd_jax/tesseract_api.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,16 @@ def apply(inputs: InputSchema) -> OutputSchema:
8686

8787
def abstract_eval(abstract_inputs):
8888
"""Calculate output shape of apply from the shape of its inputs."""
89-
is_shapedtye_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
90-
is_shapedtye_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
89+
is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
90+
is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
9191

9292
jaxified_inputs = jax.tree.map(
93-
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtye_dict(x) else x,
93+
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x,
9494
abstract_inputs.model_dump(),
95-
is_leaf=is_shapedtye_dict,
95+
is_leaf=is_shapedtype_dict,
9696
)
9797
dynamic_inputs, static_inputs = eqx.partition(
98-
jaxified_inputs, filter_spec=is_shapedtye_struct
98+
jaxified_inputs, filter_spec=is_shapedtype_struct
9999
)
100100

101101
def wrapped_apply(dynamic_inputs):
@@ -105,10 +105,10 @@ def wrapped_apply(dynamic_inputs):
105105
jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
106106
return jax.tree.map(
107107
lambda x: {"shape": x.shape, "dtype": str(x.dtype)}
108-
if is_shapedtye_struct(x)
108+
if is_shapedtype_struct(x)
109109
else x,
110110
jax_shapes,
111-
is_leaf=is_shapedtye_struct,
111+
is_leaf=is_shapedtype_struct,
112112
)
113113

114114

0 commit comments

Comments
 (0)