Skip to content

Commit bffc041

Browse files
zmheikodionhaefner
andauthored
refactor: applied jax recipe to cfd tesseract (#10)
#### Relevant issue or PR N/A #### Description of changes - Refactored cfd Tesseract according to jax recipe - Bumped jax version of cfd Tesseract to 0.6.0 #### Testing done Manual: Confirmed that filtering of static args works and vjp is not recompiled #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract-jax/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: Heiko Zimmermann [email protected] --------- Co-authored-by: Dion Häfner <[email protected]>
1 parent 20355bc commit bffc041

File tree

4 files changed

+416
-379
lines changed

4 files changed

+416
-379
lines changed
Lines changed: 126 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# Copyright 2025 Pasteur Labs. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from functools import partial
4+
from typing import Any
55

6+
import equinox as eqx
67
import jax
78
import jax.numpy as jnp
89
import jax_cfd.base as cfd
910
from pydantic import BaseModel, Field
10-
from tesseract_core.runtime import Array, Differentiable, Float32, ShapeDType
11-
12-
# TODO: !! Use JAX recipe for this, to avoid re-jitting of VJPs etc. !!
11+
from tesseract_core.runtime import Array, Differentiable, Float32
12+
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths
1313

1414

1515
class InputSchema(BaseModel):
@@ -22,45 +22,25 @@ class InputSchema(BaseModel):
2222
),
2323
Float32,
2424
]
25-
] = Field(description="3D Array defining the initial velocity field [...]")
25+
] = Field(description="3D Array defining the initial velocity field")
2626
density: float = Field(description="Density of the fluid")
2727
viscosity: float = Field(description="Viscosity of the fluid")
28-
inner_steps: float = Field(
28+
inner_steps: int = Field(
2929
description="Number of solver steps for each timestep", default=25
3030
)
31-
outer_steps: float = Field(description="Number of timesteps steps", default=10)
31+
outer_steps: int = Field(description="Number of timesteps steps", default=10)
3232
max_velocity: float = Field(description="Maximum velocity", default=2.0)
3333
cfl_safety_factor: float = Field(description="CFL safety factor", default=0.5)
3434
domain_size_x: float = Field(description="Domain size x", default=1.0)
3535
domain_size_y: float = Field(description="Domain size y", default=1.0)
3636

3737

3838
class OutputSchema(BaseModel):
39-
result: Differentiable[
40-
Array[
41-
(
42-
None,
43-
None,
44-
None,
45-
),
46-
Float32,
47-
]
48-
] = Field(description="3D Array defining the final velocity field [...]")
49-
50-
51-
@partial(
52-
jax.jit,
53-
static_argnames=(
54-
"density",
55-
"viscosity",
56-
"inner_steps",
57-
"outer_steps",
58-
"max_velocity",
59-
"cfl_safety_factor",
60-
"domain_size_x",
61-
"domain_size_y",
62-
),
63-
)
39+
result: Differentiable[Array[(None, None, None), Float32]] = Field(
40+
description="3D Array defining the final velocity field"
41+
)
42+
43+
6444
def cfd_fwd(
6545
v0: jnp.ndarray,
6646
density: float,
@@ -72,6 +52,22 @@ def cfd_fwd(
7252
domain_size_x: float,
7353
domain_size_y: float,
7454
) -> tuple[jax.Array, jax.Array]:
55+
"""Compute the final velocity field using the semi-implicit Navier-Stokes equations.
56+
57+
Args:
58+
v0: Initial velocity field.
59+
density: Density of the fluid.
60+
viscosity: Viscosity of the fluid.
61+
inner_steps: Number of solver steps for each timestep.
62+
outer_steps: Number of timesteps steps.
63+
max_velocity: Maximum velocity.
64+
cfl_safety_factor: CFL safety factor.
65+
domain_size_x: Domain size in x direction.
66+
domain_size_y: Domain size in y direction.
67+
68+
Returns:
69+
Final velocity field.
70+
"""
7571
vx0 = v0[..., 0]
7672
vy0 = v0[..., 1]
7773
bc = cfd.boundaries.HomogeneousBoundaryConditions(
@@ -89,7 +85,7 @@ def cfd_fwd(
8985
vx0 = cfd.grids.GridArray(vx0, grid=grid, offset=(1.0, 0.5))
9086
vy0 = cfd.grids.GridArray(vy0, grid=grid, offset=(0.5, 1.0))
9187

92-
# reconstrut GridVariable from input
88+
# reconstruct GridVariable from input
9389
vx0 = cfd.grids.GridVariable(vx0, bc)
9490
vy0 = cfd.grids.GridVariable(vy0, bc)
9591
v0 = (vx0, vy0)
@@ -106,80 +102,120 @@ def cfd_fwd(
106102
),
107103
steps=inner_steps,
108104
)
109-
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
105+
rollout_fn = cfd.funcutils.trajectory(step_fn, outer_steps)
110106
_, trajectory = jax.device_get(rollout_fn(v0))
111-
112107
vxn = trajectory[0].array.data[-1]
113-
114108
vyn = trajectory[1].array.data[-1]
115-
116109
return jnp.stack([vxn, vyn], axis=-1)
117110

118111

119-
def apply(inputs: InputSchema) -> OutputSchema: #
120-
vn = cfd_fwd(
121-
v0=inputs.v0,
122-
density=inputs.density,
123-
viscosity=inputs.viscosity,
124-
inner_steps=inputs.inner_steps,
125-
outer_steps=inputs.outer_steps,
126-
max_velocity=inputs.max_velocity,
127-
cfl_safety_factor=inputs.cfl_safety_factor,
128-
domain_size_x=inputs.domain_size_x,
129-
domain_size_y=inputs.domain_size_y,
130-
)
112+
@eqx.filter_jit
113+
def apply_jit(inputs: dict) -> dict:
114+
vn = cfd_fwd(**inputs)
115+
return dict(result=vn)
131116

132-
return OutputSchema(result=vn)
133117

118+
def apply(inputs: InputSchema) -> OutputSchema:
119+
return apply_jit(inputs.model_dump())
120+
121+
122+
def jacobian(
123+
inputs: InputSchema,
124+
jac_inputs: set[str],
125+
jac_outputs: set[str],
126+
):
127+
return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs))
134128

135-
def abstract_eval(abstract_inputs):
136-
"""Calculate output shape of apply from the shape of its inputs."""
137-
return {
138-
"result": ShapeDType(shape=abstract_inputs.v0.shape, dtype="float32"),
139-
}
129+
130+
def jacobian_vector_product(
131+
inputs: InputSchema,
132+
jvp_inputs: set[str],
133+
jvp_outputs: set[str],
134+
tangent_vector: dict[str, Any],
135+
):
136+
return jvp_jit(
137+
inputs.model_dump(),
138+
tuple(jvp_inputs),
139+
tuple(jvp_outputs),
140+
tangent_vector,
141+
)
140142

141143

142144
def vector_jacobian_product(
143145
inputs: InputSchema,
144146
vjp_inputs: set[str],
145147
vjp_outputs: set[str],
146-
cotangent_vector,
148+
cotangent_vector: dict[str, Any],
147149
):
148-
signature = [
149-
"v0",
150-
"density",
151-
"viscosity",
152-
"inner_steps",
153-
"outer_steps",
154-
"max_velocity",
155-
"cfl_safety_factor",
156-
"domain_size_x",
157-
"domain_size_y",
158-
]
159-
# We need to do this, rather than just use jvp inputs, as the order in jvp_inputs
160-
# is not necessarily the same as the ordering of the args in the function signature.
161-
static_args = [arg for arg in signature if arg not in vjp_inputs]
162-
nonstatic_args = [arg for arg in signature if arg in vjp_inputs]
163-
164-
def cfd_fwd_reordered(*args, **kwargs):
165-
return cfd_fwd(
166-
**{**{arg: args[i] for i, arg in enumerate(nonstatic_args)}, **kwargs}
167-
)
150+
return vjp_jit(
151+
inputs.model_dump(),
152+
tuple(vjp_inputs),
153+
tuple(vjp_outputs),
154+
cotangent_vector,
155+
)
168156

169-
out = {}
170-
if "result" in vjp_outputs:
171-
# Make the function depend only on nonstatic args, as jax.jvp
172-
# differentiates w.r.t. all free arguments.
173-
func = partial(
174-
cfd_fwd_reordered, **{arg: getattr(inputs, arg) for arg in static_args}
175-
)
176157

177-
_, vjp_func = jax.vjp(
178-
func, *tuple(inputs.model_dump(include=vjp_inputs).values())
179-
)
158+
def abstract_eval(abstract_inputs):
159+
"""Calculate output shape of apply from the shape of its inputs."""
160+
is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
161+
is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
162+
163+
jaxified_inputs = jax.tree.map(
164+
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x,
165+
abstract_inputs.model_dump(),
166+
is_leaf=is_shapedtype_dict,
167+
)
168+
dynamic_inputs, static_inputs = eqx.partition(
169+
jaxified_inputs, filter_spec=is_shapedtype_struct
170+
)
171+
172+
def wrapped_apply(dynamic_inputs):
173+
inputs = eqx.combine(static_inputs, dynamic_inputs)
174+
return apply_jit(inputs)
175+
176+
jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
177+
return jax.tree.map(
178+
lambda x: (
179+
{"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x
180+
),
181+
jax_shapes,
182+
is_leaf=is_shapedtype_struct,
183+
)
184+
185+
186+
@eqx.filter_jit
187+
def jac_jit(
188+
inputs: dict,
189+
jac_inputs: tuple[str],
190+
jac_outputs: tuple[str],
191+
):
192+
filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
193+
return jax.jacrev(filtered_apply)(
194+
flatten_with_paths(inputs, include_paths=jac_inputs)
195+
)
180196

181-
vals = vjp_func(cotangent_vector["result"])
182-
for arg, val in zip(nonstatic_args, vals, strict=False):
183-
out[arg] = out.get(arg, 0.0) + val
184197

185-
return out
198+
@eqx.filter_jit
199+
def jvp_jit(
200+
inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict
201+
):
202+
filtered_apply = filter_func(apply_jit, inputs, jvp_outputs)
203+
return jax.jvp(
204+
filtered_apply,
205+
[flatten_with_paths(inputs, include_paths=jvp_inputs)],
206+
[tangent_vector],
207+
)
208+
209+
210+
@eqx.filter_jit
211+
def vjp_jit(
212+
inputs: dict,
213+
vjp_inputs: tuple[str],
214+
vjp_outputs: tuple[str],
215+
cotangent_vector: dict,
216+
):
217+
filtered_apply = filter_func(apply_jit, inputs, vjp_outputs)
218+
_, vjp_func = jax.vjp(
219+
filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
220+
)
221+
return vjp_func(cotangent_vector)[0]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy==1.26.4
22
jax-cfd==0.2.1
3-
jax[cpu]==0.4.34
3+
jax[cpu]==0.6.0
4+
equinox

0 commit comments

Comments
 (0)