Skip to content

Commit 7f073e2

Browse files
authored
Merge branch 'main' into heiko/fix-typo-vectoradd
2 parents 7f383be + eaa5f33 commit 7f073e2

File tree

4 files changed

+416
-379
lines changed

4 files changed

+416
-379
lines changed
Lines changed: 126 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# Copyright 2025 Pasteur Labs. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from functools import partial
4+
from typing import Any
55

6+
import equinox as eqx
67
import jax
78
import jax.numpy as jnp
89
import jax_cfd.base as cfd
910
from pydantic import BaseModel, Field
10-
from tesseract_core.runtime import Array, Differentiable, Float32, ShapeDType
11-
12-
# TODO: !! Use JAX recipe for this, to avoid re-jitting of VJPs etc. !!
11+
from tesseract_core.runtime import Array, Differentiable, Float32
12+
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths
1313

1414

1515
class InputSchema(BaseModel):
@@ -22,45 +22,25 @@ class InputSchema(BaseModel):
2222
),
2323
Float32,
2424
]
25-
] = Field(description="3D Array defining the initial velocity field [...]")
25+
] = Field(description="3D Array defining the initial velocity field")
2626
density: float = Field(description="Density of the fluid")
2727
viscosity: float = Field(description="Viscosity of the fluid")
28-
inner_steps: float = Field(
28+
inner_steps: int = Field(
2929
description="Number of solver steps for each timestep", default=25
3030
)
31-
outer_steps: float = Field(description="Number of timesteps steps", default=10)
31+
outer_steps: int = Field(description="Number of timesteps steps", default=10)
3232
max_velocity: float = Field(description="Maximum velocity", default=2.0)
3333
cfl_safety_factor: float = Field(description="CFL safety factor", default=0.5)
3434
domain_size_x: float = Field(description="Domain size x", default=1.0)
3535
domain_size_y: float = Field(description="Domain size y", default=1.0)
3636

3737

3838
class OutputSchema(BaseModel):
39-
result: Differentiable[
40-
Array[
41-
(
42-
None,
43-
None,
44-
None,
45-
),
46-
Float32,
47-
]
48-
] = Field(description="3D Array defining the final velocity field [...]")
49-
50-
51-
@partial(
52-
jax.jit,
53-
static_argnames=(
54-
"density",
55-
"viscosity",
56-
"inner_steps",
57-
"outer_steps",
58-
"max_velocity",
59-
"cfl_safety_factor",
60-
"domain_size_x",
61-
"domain_size_y",
62-
),
63-
)
39+
result: Differentiable[Array[(None, None, None), Float32]] = Field(
40+
description="3D Array defining the final velocity field"
41+
)
42+
43+
6444
def cfd_fwd(
6545
v0: jnp.ndarray,
6646
density: float,
@@ -72,6 +52,22 @@ def cfd_fwd(
7252
domain_size_x: float,
7353
domain_size_y: float,
7454
) -> tuple[jax.Array, jax.Array]:
55+
"""Compute the final velocity field using the semi-implicit Navier-Stokes equations.
56+
57+
Args:
58+
v0: Initial velocity field.
59+
density: Density of the fluid.
60+
viscosity: Viscosity of the fluid.
61+
inner_steps: Number of solver steps for each timestep.
62+
outer_steps: Number of timesteps steps.
63+
max_velocity: Maximum velocity.
64+
cfl_safety_factor: CFL safety factor.
65+
domain_size_x: Domain size in x direction.
66+
domain_size_y: Domain size in y direction.
67+
68+
Returns:
69+
Final velocity field.
70+
"""
7571
vx0 = v0[..., 0]
7672
vy0 = v0[..., 1]
7773
bc = cfd.boundaries.HomogeneousBoundaryConditions(
@@ -89,7 +85,7 @@ def cfd_fwd(
8985
vx0 = cfd.grids.GridArray(vx0, grid=grid, offset=(1.0, 0.5))
9086
vy0 = cfd.grids.GridArray(vy0, grid=grid, offset=(0.5, 1.0))
9187

92-
# reconstrut GridVariable from input
88+
# reconstruct GridVariable from input
9389
vx0 = cfd.grids.GridVariable(vx0, bc)
9490
vy0 = cfd.grids.GridVariable(vy0, bc)
9591
v0 = (vx0, vy0)
@@ -106,80 +102,120 @@ def cfd_fwd(
106102
),
107103
steps=inner_steps,
108104
)
109-
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
105+
rollout_fn = cfd.funcutils.trajectory(step_fn, outer_steps)
110106
_, trajectory = jax.device_get(rollout_fn(v0))
111-
112107
vxn = trajectory[0].array.data[-1]
113-
114108
vyn = trajectory[1].array.data[-1]
115-
116109
return jnp.stack([vxn, vyn], axis=-1)
117110

118111

119-
def apply(inputs: InputSchema) -> OutputSchema: #
120-
vn = cfd_fwd(
121-
v0=inputs.v0,
122-
density=inputs.density,
123-
viscosity=inputs.viscosity,
124-
inner_steps=inputs.inner_steps,
125-
outer_steps=inputs.outer_steps,
126-
max_velocity=inputs.max_velocity,
127-
cfl_safety_factor=inputs.cfl_safety_factor,
128-
domain_size_x=inputs.domain_size_x,
129-
domain_size_y=inputs.domain_size_y,
130-
)
112+
@eqx.filter_jit
113+
def apply_jit(inputs: dict) -> dict:
114+
vn = cfd_fwd(**inputs)
115+
return dict(result=vn)
131116

132-
return OutputSchema(result=vn)
133117

118+
def apply(inputs: InputSchema) -> OutputSchema:
119+
return apply_jit(inputs.model_dump())
134120

135-
def abstract_eval(abstract_inputs):
136-
"""Calculate output shape of apply from the shape of its inputs."""
137-
return {
138-
"result": ShapeDType(shape=abstract_inputs.v0.shape, dtype="float32"),
139-
}
121+
122+
def jacobian(
123+
inputs: InputSchema,
124+
jac_inputs: set[str],
125+
jac_outputs: set[str],
126+
):
127+
return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs))
128+
129+
130+
def jacobian_vector_product(
131+
inputs: InputSchema,
132+
jvp_inputs: set[str],
133+
jvp_outputs: set[str],
134+
tangent_vector: dict[str, Any],
135+
):
136+
return jvp_jit(
137+
inputs.model_dump(),
138+
tuple(jvp_inputs),
139+
tuple(jvp_outputs),
140+
tangent_vector,
141+
)
140142

141143

142144
def vector_jacobian_product(
143145
inputs: InputSchema,
144146
vjp_inputs: set[str],
145147
vjp_outputs: set[str],
146-
cotangent_vector,
148+
cotangent_vector: dict[str, Any],
147149
):
148-
signature = [
149-
"v0",
150-
"density",
151-
"viscosity",
152-
"inner_steps",
153-
"outer_steps",
154-
"max_velocity",
155-
"cfl_safety_factor",
156-
"domain_size_x",
157-
"domain_size_y",
158-
]
159-
# We need to do this, rather than just use jvp inputs, as the order in jvp_inputs
160-
# is not necessarily the same as the ordering of the args in the function signature.
161-
static_args = [arg for arg in signature if arg not in vjp_inputs]
162-
nonstatic_args = [arg for arg in signature if arg in vjp_inputs]
163-
164-
def cfd_fwd_reordered(*args, **kwargs):
165-
return cfd_fwd(
166-
**{**{arg: args[i] for i, arg in enumerate(nonstatic_args)}, **kwargs}
167-
)
150+
return vjp_jit(
151+
inputs.model_dump(),
152+
tuple(vjp_inputs),
153+
tuple(vjp_outputs),
154+
cotangent_vector,
155+
)
168156

169-
out = {}
170-
if "result" in vjp_outputs:
171-
# Make the function depend only on nonstatic args, as jax.jvp
172-
# differentiates w.r.t. all free arguments.
173-
func = partial(
174-
cfd_fwd_reordered, **{arg: getattr(inputs, arg) for arg in static_args}
175-
)
176157

177-
_, vjp_func = jax.vjp(
178-
func, *tuple(inputs.model_dump(include=vjp_inputs).values())
179-
)
158+
def abstract_eval(abstract_inputs):
159+
"""Calculate output shape of apply from the shape of its inputs."""
160+
is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
161+
is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
162+
163+
jaxified_inputs = jax.tree.map(
164+
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x,
165+
abstract_inputs.model_dump(),
166+
is_leaf=is_shapedtype_dict,
167+
)
168+
dynamic_inputs, static_inputs = eqx.partition(
169+
jaxified_inputs, filter_spec=is_shapedtype_struct
170+
)
180171

181-
vals = vjp_func(cotangent_vector["result"])
182-
for arg, val in zip(nonstatic_args, vals, strict=False):
183-
out[arg] = out.get(arg, 0.0) + val
172+
def wrapped_apply(dynamic_inputs):
173+
inputs = eqx.combine(static_inputs, dynamic_inputs)
174+
return apply_jit(inputs)
184175

185-
return out
176+
jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
177+
return jax.tree.map(
178+
lambda x: (
179+
{"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x
180+
),
181+
jax_shapes,
182+
is_leaf=is_shapedtype_struct,
183+
)
184+
185+
186+
@eqx.filter_jit
187+
def jac_jit(
188+
inputs: dict,
189+
jac_inputs: tuple[str],
190+
jac_outputs: tuple[str],
191+
):
192+
filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
193+
return jax.jacrev(filtered_apply)(
194+
flatten_with_paths(inputs, include_paths=jac_inputs)
195+
)
196+
197+
198+
@eqx.filter_jit
199+
def jvp_jit(
200+
inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict
201+
):
202+
filtered_apply = filter_func(apply_jit, inputs, jvp_outputs)
203+
return jax.jvp(
204+
filtered_apply,
205+
[flatten_with_paths(inputs, include_paths=jvp_inputs)],
206+
[tangent_vector],
207+
)[1]
208+
209+
210+
@eqx.filter_jit
211+
def vjp_jit(
212+
inputs: dict,
213+
vjp_inputs: tuple[str],
214+
vjp_outputs: tuple[str],
215+
cotangent_vector: dict,
216+
):
217+
filtered_apply = filter_func(apply_jit, inputs, vjp_outputs)
218+
_, vjp_func = jax.vjp(
219+
filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
220+
)
221+
return vjp_func(cotangent_vector)[0]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy==1.26.4
22
jax-cfd==0.2.1
3-
jax[cpu]==0.4.34
3+
jax[cpu]==0.6.0
4+
equinox

0 commit comments

Comments
 (0)