diff --git a/demo/simple/vectoradd_jax/tesseract_api.py b/demo/simple/vectoradd_jax/tesseract_api.py index 9aecff7..5451902 100644 --- a/demo/simple/vectoradd_jax/tesseract_api.py +++ b/demo/simple/vectoradd_jax/tesseract_api.py @@ -86,16 +86,16 @@ def apply(inputs: InputSchema) -> OutputSchema: def abstract_eval(abstract_inputs): """Calculate output shape of apply from the shape of its inputs.""" - is_shapedtye_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) - is_shapedtye_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) + is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) + is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) jaxified_inputs = jax.tree.map( - lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtye_dict(x) else x, + lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x, abstract_inputs.model_dump(), - is_leaf=is_shapedtye_dict, + is_leaf=is_shapedtype_dict, ) dynamic_inputs, static_inputs = eqx.partition( - jaxified_inputs, filter_spec=is_shapedtye_struct + jaxified_inputs, filter_spec=is_shapedtype_struct ) def wrapped_apply(dynamic_inputs): @@ -105,10 +105,10 @@ def wrapped_apply(dynamic_inputs): jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs) return jax.tree.map( lambda x: {"shape": x.shape, "dtype": str(x.dtype)} - if is_shapedtye_struct(x) + if is_shapedtype_struct(x) else x, jax_shapes, - is_leaf=is_shapedtye_struct, + is_leaf=is_shapedtype_struct, )