Skip to content

Commit bd73ca5

Browse files
committed
Applied jax recipe to cfd tesseract
1 parent 466fb40 commit bd73ca5

File tree

2 files changed

+127
-109
lines changed

2 files changed

+127
-109
lines changed

demo/cfd/cfd-tesseract/tesseract_api.py

Lines changed: 125 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
# Copyright 2025 Pasteur Labs. All Rights Reserved.
2-
# SPDX-License-Identifier: Apache-2.0
3-
4-
from functools import partial
1+
from typing import Any
52

3+
import equinox as eqx
64
import jax
75
import jax.numpy as jnp
86
import jax_cfd.base as cfd
97
from pydantic import BaseModel, Field
10-
from tesseract_core.runtime import Array, Differentiable, Float32, ShapeDType
11-
12-
# TODO: !! Use JAX recipe for this, to avoid re-jitting of VJPs etc. !!
8+
from tesseract_core.runtime import Array, Differentiable, Float32
9+
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths
1310

1411

1512
class InputSchema(BaseModel):
@@ -25,42 +22,134 @@ class InputSchema(BaseModel):
2522
] = Field(description="3D Array defining the initial velocity field [...]")
2623
density: float = Field(description="Density of the fluid")
2724
viscosity: float = Field(description="Viscosity of the fluid")
28-
inner_steps: float = Field(
25+
inner_steps: int = Field(
2926
description="Number of solver steps for each timestep", default=25
3027
)
31-
outer_steps: float = Field(description="Number of timesteps steps", default=10)
28+
outer_steps: int = Field(description="Number of timesteps steps", default=10)
3229
max_velocity: float = Field(description="Maximum velocity", default=2.0)
3330
cfl_safety_factor: float = Field(description="CFL safety factor", default=0.5)
3431
domain_size_x: float = Field(description="Domain size x", default=1.0)
3532
domain_size_y: float = Field(description="Domain size y", default=1.0)
3633

3734

3835
class OutputSchema(BaseModel):
39-
result: Differentiable[
40-
Array[
41-
(
42-
None,
43-
None,
44-
None,
45-
),
46-
Float32,
47-
]
48-
] = Field(description="3D Array defining the final velocity field [...]")
49-
50-
51-
@partial(
52-
jax.jit,
53-
static_argnames=(
54-
"density",
55-
"viscosity",
56-
"inner_steps",
57-
"outer_steps",
58-
"max_velocity",
59-
"cfl_safety_factor",
60-
"domain_size_x",
61-
"domain_size_y",
62-
),
63-
)
36+
result: Differentiable[Array[(None, None, None), Float32]] = Field(
37+
description="3D Array defining the final velocity field [...]"
38+
)
39+
40+
41+
@eqx.filter_jit
42+
def apply_jit(inputs: dict) -> dict:
43+
vn = cfd_fwd(**inputs)
44+
return dict(result=vn)
45+
46+
47+
def apply(inputs: InputSchema) -> OutputSchema:
48+
return apply_jit(inputs.model_dump())
49+
50+
51+
def jacobian(
52+
inputs: InputSchema,
53+
jac_inputs: set[str],
54+
jac_outputs: set[str],
55+
):
56+
return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs))
57+
58+
59+
def jacobian_vector_product(
60+
inputs: InputSchema,
61+
jvp_inputs: set[str],
62+
jvp_outputs: set[str],
63+
tangent_vector: dict[str, Any],
64+
):
65+
return jvp_jit(
66+
inputs.model_dump(),
67+
tuple(jvp_inputs),
68+
tuple(jvp_outputs),
69+
tangent_vector,
70+
)
71+
72+
73+
def vector_jacobian_product(
74+
inputs: InputSchema,
75+
vjp_inputs: set[str],
76+
vjp_outputs: set[str],
77+
cotangent_vector: dict[str, Any],
78+
):
79+
return vjp_jit(
80+
inputs.model_dump(),
81+
tuple(vjp_inputs),
82+
tuple(vjp_outputs),
83+
cotangent_vector,
84+
)
85+
86+
87+
def abstract_eval(abstract_inputs):
88+
"""Calculate output shape of apply from the shape of its inputs."""
89+
is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
90+
is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
91+
92+
jaxified_inputs = jax.tree.map(
93+
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x,
94+
abstract_inputs.model_dump(),
95+
is_leaf=is_shapedtype_dict,
96+
)
97+
dynamic_inputs, static_inputs = eqx.partition(
98+
jaxified_inputs, filter_spec=is_shapedtype_struct
99+
)
100+
101+
def wrapped_apply(dynamic_inputs):
102+
inputs = eqx.combine(static_inputs, dynamic_inputs)
103+
return apply_jit(inputs)
104+
105+
jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
106+
return jax.tree.map(
107+
lambda x: (
108+
{"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x
109+
),
110+
jax_shapes,
111+
is_leaf=is_shapedtype_struct,
112+
)
113+
114+
115+
@eqx.filter_jit
116+
def jac_jit(
117+
inputs: dict,
118+
jac_inputs: tuple[str],
119+
jac_outputs: tuple[str],
120+
):
121+
filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
122+
return jax.jacrev(filtered_apply)(
123+
flatten_with_paths(inputs, include_paths=jac_inputs)
124+
)
125+
126+
127+
@eqx.filter_jit
128+
def jvp_jit(
129+
inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict
130+
):
131+
filtered_apply = filter_func(apply_jit, inputs, jvp_outputs)
132+
return jax.jvp(
133+
filtered_apply,
134+
[flatten_with_paths(inputs, include_paths=jvp_inputs)],
135+
[tangent_vector],
136+
)
137+
138+
139+
@eqx.filter_jit
140+
def vjp_jit(
141+
inputs: dict,
142+
vjp_inputs: tuple[str],
143+
vjp_outputs: tuple[str],
144+
cotangent_vector: dict,
145+
):
146+
filtered_apply = filter_func(apply_jit, inputs, vjp_outputs)
147+
_, vjp_func = jax.vjp(
148+
filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
149+
)
150+
return vjp_func(cotangent_vector)[0]
151+
152+
64153
def cfd_fwd(
65154
v0: jnp.ndarray,
66155
density: float,
@@ -89,7 +178,7 @@ def cfd_fwd(
89178
vx0 = cfd.grids.GridArray(vx0, grid=grid, offset=(1.0, 0.5))
90179
vy0 = cfd.grids.GridArray(vy0, grid=grid, offset=(0.5, 1.0))
91180

92-
# reconstrut GridVariable from input
181+
# reconstruct GridVariable from input
93182
vx0 = cfd.grids.GridVariable(vx0, bc)
94183
vy0 = cfd.grids.GridVariable(vy0, bc)
95184
v0 = (vx0, vy0)
@@ -106,80 +195,8 @@ def cfd_fwd(
106195
),
107196
steps=inner_steps,
108197
)
109-
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
198+
rollout_fn = cfd.funcutils.trajectory(step_fn, outer_steps)
110199
_, trajectory = jax.device_get(rollout_fn(v0))
111-
112200
vxn = trajectory[0].array.data[-1]
113-
114201
vyn = trajectory[1].array.data[-1]
115-
116202
return jnp.stack([vxn, vyn], axis=-1)
117-
118-
119-
def apply(inputs: InputSchema) -> OutputSchema: #
120-
vn = cfd_fwd(
121-
v0=inputs.v0,
122-
density=inputs.density,
123-
viscosity=inputs.viscosity,
124-
inner_steps=inputs.inner_steps,
125-
outer_steps=inputs.outer_steps,
126-
max_velocity=inputs.max_velocity,
127-
cfl_safety_factor=inputs.cfl_safety_factor,
128-
domain_size_x=inputs.domain_size_x,
129-
domain_size_y=inputs.domain_size_y,
130-
)
131-
132-
return OutputSchema(result=vn)
133-
134-
135-
def abstract_eval(abstract_inputs):
136-
"""Calculate output shape of apply from the shape of its inputs."""
137-
return {
138-
"result": ShapeDType(shape=abstract_inputs.v0.shape, dtype="float32"),
139-
}
140-
141-
142-
def vector_jacobian_product(
143-
inputs: InputSchema,
144-
vjp_inputs: set[str],
145-
vjp_outputs: set[str],
146-
cotangent_vector,
147-
):
148-
signature = [
149-
"v0",
150-
"density",
151-
"viscosity",
152-
"inner_steps",
153-
"outer_steps",
154-
"max_velocity",
155-
"cfl_safety_factor",
156-
"domain_size_x",
157-
"domain_size_y",
158-
]
159-
# We need to do this, rather than just use jvp inputs, as the order in jvp_inputs
160-
# is not necessarily the same as the ordering of the args in the function signature.
161-
static_args = [arg for arg in signature if arg not in vjp_inputs]
162-
nonstatic_args = [arg for arg in signature if arg in vjp_inputs]
163-
164-
def cfd_fwd_reordered(*args, **kwargs):
165-
return cfd_fwd(
166-
**{**{arg: args[i] for i, arg in enumerate(nonstatic_args)}, **kwargs}
167-
)
168-
169-
out = {}
170-
if "result" in vjp_outputs:
171-
# Make the function depend only on nonstatic args, as jax.jvp
172-
# differentiates w.r.t. all free arguments.
173-
func = partial(
174-
cfd_fwd_reordered, **{arg: getattr(inputs, arg) for arg in static_args}
175-
)
176-
177-
_, vjp_func = jax.vjp(
178-
func, *tuple(inputs.model_dump(include=vjp_inputs).values())
179-
)
180-
181-
vals = vjp_func(cotangent_vector["result"])
182-
for arg, val in zip(nonstatic_args, vals, strict=False):
183-
out[arg] = out.get(arg, 0.0) + val
184-
185-
return out
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy==1.26.4
22
jax-cfd==0.2.1
3-
jax[cpu]==0.4.34
3+
jax[cpu]==0.6.0
4+
equinox

0 commit comments

Comments
 (0)