11# Copyright 2025 Pasteur Labs. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4- from functools import partial
4+ from typing import Any
55
6+ import equinox as eqx
67import jax
78import jax .numpy as jnp
89import jax_cfd .base as cfd
910from pydantic import BaseModel , Field
10- from tesseract_core .runtime import Array , Differentiable , Float32 , ShapeDType
11-
12- # TODO: !! Use JAX recipe for this, to avoid re-jitting of VJPs etc. !!
11+ from tesseract_core .runtime import Array , Differentiable , Float32
12+ from tesseract_core .runtime .tree_transforms import filter_func , flatten_with_paths
1313
1414
1515class InputSchema (BaseModel ):
@@ -22,45 +22,25 @@ class InputSchema(BaseModel):
2222 ),
2323 Float32 ,
2424 ]
25- ] = Field (description = "3D Array defining the initial velocity field [...] " )
25+ ] = Field (description = "3D Array defining the initial velocity field" )
2626 density : float = Field (description = "Density of the fluid" )
2727 viscosity : float = Field (description = "Viscosity of the fluid" )
28- inner_steps : float = Field (
28+ inner_steps : int = Field (
2929 description = "Number of solver steps for each timestep" , default = 25
3030 )
31- outer_steps : float = Field (description = "Number of timesteps steps" , default = 10 )
31+ outer_steps : int = Field (description = "Number of timesteps steps" , default = 10 )
3232 max_velocity : float = Field (description = "Maximum velocity" , default = 2.0 )
3333 cfl_safety_factor : float = Field (description = "CFL safety factor" , default = 0.5 )
3434 domain_size_x : float = Field (description = "Domain size x" , default = 1.0 )
3535 domain_size_y : float = Field (description = "Domain size y" , default = 1.0 )
3636
3737
3838class OutputSchema (BaseModel ):
39- result : Differentiable [
40- Array [
41- (
42- None ,
43- None ,
44- None ,
45- ),
46- Float32 ,
47- ]
48- ] = Field (description = "3D Array defining the final velocity field [...]" )
49-
50-
51- @partial (
52- jax .jit ,
53- static_argnames = (
54- "density" ,
55- "viscosity" ,
56- "inner_steps" ,
57- "outer_steps" ,
58- "max_velocity" ,
59- "cfl_safety_factor" ,
60- "domain_size_x" ,
61- "domain_size_y" ,
62- ),
63- )
39+ result : Differentiable [Array [(None , None , None ), Float32 ]] = Field (
40+ description = "3D Array defining the final velocity field"
41+ )
42+
43+
6444def cfd_fwd (
6545 v0 : jnp .ndarray ,
6646 density : float ,
@@ -72,6 +52,22 @@ def cfd_fwd(
7252 domain_size_x : float ,
7353 domain_size_y : float ,
7454) -> tuple [jax .Array , jax .Array ]:
55+ """Compute the final velocity field using the semi-implicit Navier-Stokes equations.
56+
57+ Args:
58+ v0: Initial velocity field.
59+ density: Density of the fluid.
60+ viscosity: Viscosity of the fluid.
61+ inner_steps: Number of solver steps for each timestep.
62+ outer_steps: Number of timesteps steps.
63+ max_velocity: Maximum velocity.
64+ cfl_safety_factor: CFL safety factor.
65+ domain_size_x: Domain size in x direction.
66+ domain_size_y: Domain size in y direction.
67+
68+ Returns:
69+ Final velocity field.
70+ """
7571 vx0 = v0 [..., 0 ]
7672 vy0 = v0 [..., 1 ]
7773 bc = cfd .boundaries .HomogeneousBoundaryConditions (
@@ -89,7 +85,7 @@ def cfd_fwd(
8985 vx0 = cfd .grids .GridArray (vx0 , grid = grid , offset = (1.0 , 0.5 ))
9086 vy0 = cfd .grids .GridArray (vy0 , grid = grid , offset = (0.5 , 1.0 ))
9187
92- # reconstrut GridVariable from input
88+ # reconstruct GridVariable from input
9389 vx0 = cfd .grids .GridVariable (vx0 , bc )
9490 vy0 = cfd .grids .GridVariable (vy0 , bc )
9591 v0 = (vx0 , vy0 )
@@ -106,80 +102,120 @@ def cfd_fwd(
106102 ),
107103 steps = inner_steps ,
108104 )
109- rollout_fn = jax . jit ( cfd .funcutils .trajectory (step_fn , outer_steps ) )
105+ rollout_fn = cfd .funcutils .trajectory (step_fn , outer_steps )
110106 _ , trajectory = jax .device_get (rollout_fn (v0 ))
111-
112107 vxn = trajectory [0 ].array .data [- 1 ]
113-
114108 vyn = trajectory [1 ].array .data [- 1 ]
115-
116109 return jnp .stack ([vxn , vyn ], axis = - 1 )
117110
118111
119- def apply (inputs : InputSchema ) -> OutputSchema : #
120- vn = cfd_fwd (
121- v0 = inputs .v0 ,
122- density = inputs .density ,
123- viscosity = inputs .viscosity ,
124- inner_steps = inputs .inner_steps ,
125- outer_steps = inputs .outer_steps ,
126- max_velocity = inputs .max_velocity ,
127- cfl_safety_factor = inputs .cfl_safety_factor ,
128- domain_size_x = inputs .domain_size_x ,
129- domain_size_y = inputs .domain_size_y ,
130- )
112+ @eqx .filter_jit
113+ def apply_jit (inputs : dict ) -> dict :
114+ vn = cfd_fwd (** inputs )
115+ return dict (result = vn )
131116
132- return OutputSchema (result = vn )
133117
118+ def apply (inputs : InputSchema ) -> OutputSchema :
119+ return apply_jit (inputs .model_dump ())
134120
135- def abstract_eval (abstract_inputs ):
136- """Calculate output shape of apply from the shape of its inputs."""
137- return {
138- "result" : ShapeDType (shape = abstract_inputs .v0 .shape , dtype = "float32" ),
139- }
121+
122+ def jacobian (
123+ inputs : InputSchema ,
124+ jac_inputs : set [str ],
125+ jac_outputs : set [str ],
126+ ):
127+ return jac_jit (inputs .model_dump (), tuple (jac_inputs ), tuple (jac_outputs ))
128+
129+
130+ def jacobian_vector_product (
131+ inputs : InputSchema ,
132+ jvp_inputs : set [str ],
133+ jvp_outputs : set [str ],
134+ tangent_vector : dict [str , Any ],
135+ ):
136+ return jvp_jit (
137+ inputs .model_dump (),
138+ tuple (jvp_inputs ),
139+ tuple (jvp_outputs ),
140+ tangent_vector ,
141+ )
140142
141143
142144def vector_jacobian_product (
143145 inputs : InputSchema ,
144146 vjp_inputs : set [str ],
145147 vjp_outputs : set [str ],
146- cotangent_vector ,
148+ cotangent_vector : dict [ str , Any ] ,
147149):
148- signature = [
149- "v0" ,
150- "density" ,
151- "viscosity" ,
152- "inner_steps" ,
153- "outer_steps" ,
154- "max_velocity" ,
155- "cfl_safety_factor" ,
156- "domain_size_x" ,
157- "domain_size_y" ,
158- ]
159- # We need to do this, rather than just use jvp inputs, as the order in jvp_inputs
160- # is not necessarily the same as the ordering of the args in the function signature.
161- static_args = [arg for arg in signature if arg not in vjp_inputs ]
162- nonstatic_args = [arg for arg in signature if arg in vjp_inputs ]
163-
164- def cfd_fwd_reordered (* args , ** kwargs ):
165- return cfd_fwd (
166- ** {** {arg : args [i ] for i , arg in enumerate (nonstatic_args )}, ** kwargs }
167- )
150+ return vjp_jit (
151+ inputs .model_dump (),
152+ tuple (vjp_inputs ),
153+ tuple (vjp_outputs ),
154+ cotangent_vector ,
155+ )
168156
169- out = {}
170- if "result" in vjp_outputs :
171- # Make the function depend only on nonstatic args, as jax.jvp
172- # differentiates w.r.t. all free arguments.
173- func = partial (
174- cfd_fwd_reordered , ** {arg : getattr (inputs , arg ) for arg in static_args }
175- )
176157
177- _ , vjp_func = jax .vjp (
178- func , * tuple (inputs .model_dump (include = vjp_inputs ).values ())
179- )
158+ def abstract_eval (abstract_inputs ):
159+ """Calculate output shape of apply from the shape of its inputs."""
160+ is_shapedtype_dict = lambda x : type (x ) is dict and (x .keys () == {"shape" , "dtype" })
161+ is_shapedtype_struct = lambda x : isinstance (x , jax .ShapeDtypeStruct )
162+
163+ jaxified_inputs = jax .tree .map (
164+ lambda x : jax .ShapeDtypeStruct (** x ) if is_shapedtype_dict (x ) else x ,
165+ abstract_inputs .model_dump (),
166+ is_leaf = is_shapedtype_dict ,
167+ )
168+ dynamic_inputs , static_inputs = eqx .partition (
169+ jaxified_inputs , filter_spec = is_shapedtype_struct
170+ )
180171
181- vals = vjp_func ( cotangent_vector [ "result" ])
182- for arg , val in zip ( nonstatic_args , vals , strict = False ):
183- out [ arg ] = out . get ( arg , 0.0 ) + val
172+ def wrapped_apply ( dynamic_inputs ):
173+ inputs = eqx . combine ( static_inputs , dynamic_inputs )
174+ return apply_jit ( inputs )
184175
185- return out
176+ jax_shapes = jax .eval_shape (wrapped_apply , dynamic_inputs )
177+ return jax .tree .map (
178+ lambda x : (
179+ {"shape" : x .shape , "dtype" : str (x .dtype )} if is_shapedtype_struct (x ) else x
180+ ),
181+ jax_shapes ,
182+ is_leaf = is_shapedtype_struct ,
183+ )
184+
185+
186+ @eqx .filter_jit
187+ def jac_jit (
188+ inputs : dict ,
189+ jac_inputs : tuple [str ],
190+ jac_outputs : tuple [str ],
191+ ):
192+ filtered_apply = filter_func (apply_jit , inputs , jac_outputs )
193+ return jax .jacrev (filtered_apply )(
194+ flatten_with_paths (inputs , include_paths = jac_inputs )
195+ )
196+
197+
198+ @eqx .filter_jit
199+ def jvp_jit (
200+ inputs : dict , jvp_inputs : tuple [str ], jvp_outputs : tuple [str ], tangent_vector : dict
201+ ):
202+ filtered_apply = filter_func (apply_jit , inputs , jvp_outputs )
203+ return jax .jvp (
204+ filtered_apply ,
205+ [flatten_with_paths (inputs , include_paths = jvp_inputs )],
206+ [tangent_vector ],
207+ )[1 ]
208+
209+
210+ @eqx .filter_jit
211+ def vjp_jit (
212+ inputs : dict ,
213+ vjp_inputs : tuple [str ],
214+ vjp_outputs : tuple [str ],
215+ cotangent_vector : dict ,
216+ ):
217+ filtered_apply = filter_func (apply_jit , inputs , vjp_outputs )
218+ _ , vjp_func = jax .vjp (
219+ filtered_apply , flatten_with_paths (inputs , include_paths = vjp_inputs )
220+ )
221+ return vjp_func (cotangent_vector )[0 ]
0 commit comments