@@ -86,16 +86,16 @@ def apply(inputs: InputSchema) -> OutputSchema:
8686
8787def abstract_eval (abstract_inputs ):
8888 """Calculate output shape of apply from the shape of its inputs."""
89- is_shapedtye_dict = lambda x : type (x ) is dict and (x .keys () == {"shape" , "dtype" })
90- is_shapedtye_struct = lambda x : isinstance (x , jax .ShapeDtypeStruct )
89+ is_shapedtype_dict = lambda x : type (x ) is dict and (x .keys () == {"shape" , "dtype" })
90+ is_shapedtype_struct = lambda x : isinstance (x , jax .ShapeDtypeStruct )
9191
9292 jaxified_inputs = jax .tree .map (
93- lambda x : jax .ShapeDtypeStruct (** x ) if is_shapedtye_dict (x ) else x ,
93+ lambda x : jax .ShapeDtypeStruct (** x ) if is_shapedtype_dict (x ) else x ,
9494 abstract_inputs .model_dump (),
95- is_leaf = is_shapedtye_dict ,
95+ is_leaf = is_shapedtype_dict ,
9696 )
9797 dynamic_inputs , static_inputs = eqx .partition (
98- jaxified_inputs , filter_spec = is_shapedtye_struct
98+ jaxified_inputs , filter_spec = is_shapedtype_struct
9999 )
100100
101101 def wrapped_apply (dynamic_inputs ):
@@ -105,10 +105,10 @@ def wrapped_apply(dynamic_inputs):
105105 jax_shapes = jax .eval_shape (wrapped_apply , dynamic_inputs )
106106 return jax .tree .map (
107107 lambda x : {"shape" : x .shape , "dtype" : str (x .dtype )}
108- if is_shapedtye_struct (x )
108+ if is_shapedtype_struct (x )
109109 else x ,
110110 jax_shapes ,
111- is_leaf = is_shapedtye_struct ,
111+ is_leaf = is_shapedtype_struct ,
112112 )
113113
114114
0 commit comments