Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 126 additions & 90 deletions demo/cfd/cfd-tesseract/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
import jax_cfd.base as cfd
from pydantic import BaseModel, Field
from tesseract_core.runtime import Array, Differentiable, Float32, ShapeDType

# TODO: !! Use JAX recipe for this, to avoid re-jitting of VJPs etc. !!
from tesseract_core.runtime import Array, Differentiable, Float32
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths


class InputSchema(BaseModel):
Expand All @@ -22,45 +22,25 @@ class InputSchema(BaseModel):
),
Float32,
]
] = Field(description="3D Array defining the initial velocity field [...]")
] = Field(description="3D Array defining the initial velocity field")
density: float = Field(description="Density of the fluid")
viscosity: float = Field(description="Viscosity of the fluid")
inner_steps: float = Field(
inner_steps: int = Field(
description="Number of solver steps for each timestep", default=25
)
outer_steps: float = Field(description="Number of timesteps steps", default=10)
outer_steps: int = Field(description="Number of timesteps steps", default=10)
max_velocity: float = Field(description="Maximum velocity", default=2.0)
cfl_safety_factor: float = Field(description="CFL safety factor", default=0.5)
domain_size_x: float = Field(description="Domain size x", default=1.0)
domain_size_y: float = Field(description="Domain size y", default=1.0)


class OutputSchema(BaseModel):
result: Differentiable[
Array[
(
None,
None,
None,
),
Float32,
]
] = Field(description="3D Array defining the final velocity field [...]")


@partial(
jax.jit,
static_argnames=(
"density",
"viscosity",
"inner_steps",
"outer_steps",
"max_velocity",
"cfl_safety_factor",
"domain_size_x",
"domain_size_y",
),
)
result: Differentiable[Array[(None, None, None), Float32]] = Field(
description="3D Array defining the final velocity field"
)


def cfd_fwd(
v0: jnp.ndarray,
density: float,
Expand All @@ -72,6 +52,22 @@ def cfd_fwd(
domain_size_x: float,
domain_size_y: float,
) -> tuple[jax.Array, jax.Array]:
"""Compute the final velocity field using the semi-implicit Navier-Stokes equations.

Args:
v0: Initial velocity field.
density: Density of the fluid.
viscosity: Viscosity of the fluid.
inner_steps: Number of solver steps for each timestep.
outer_steps: Number of timesteps steps.
max_velocity: Maximum velocity.
cfl_safety_factor: CFL safety factor.
domain_size_x: Domain size in x direction.
domain_size_y: Domain size in y direction.

Returns:
Final velocity field.
"""
vx0 = v0[..., 0]
vy0 = v0[..., 1]
bc = cfd.boundaries.HomogeneousBoundaryConditions(
Expand All @@ -89,7 +85,7 @@ def cfd_fwd(
vx0 = cfd.grids.GridArray(vx0, grid=grid, offset=(1.0, 0.5))
vy0 = cfd.grids.GridArray(vy0, grid=grid, offset=(0.5, 1.0))

# reconstrut GridVariable from input
# reconstruct GridVariable from input
vx0 = cfd.grids.GridVariable(vx0, bc)
vy0 = cfd.grids.GridVariable(vy0, bc)
v0 = (vx0, vy0)
Expand All @@ -106,80 +102,120 @@ def cfd_fwd(
),
steps=inner_steps,
)
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
rollout_fn = cfd.funcutils.trajectory(step_fn, outer_steps)
_, trajectory = jax.device_get(rollout_fn(v0))

vxn = trajectory[0].array.data[-1]

vyn = trajectory[1].array.data[-1]

return jnp.stack([vxn, vyn], axis=-1)


def apply(inputs: InputSchema) -> OutputSchema: #
vn = cfd_fwd(
v0=inputs.v0,
density=inputs.density,
viscosity=inputs.viscosity,
inner_steps=inputs.inner_steps,
outer_steps=inputs.outer_steps,
max_velocity=inputs.max_velocity,
cfl_safety_factor=inputs.cfl_safety_factor,
domain_size_x=inputs.domain_size_x,
domain_size_y=inputs.domain_size_y,
)
@eqx.filter_jit
def apply_jit(inputs: dict) -> dict:
vn = cfd_fwd(**inputs)
return dict(result=vn)

return OutputSchema(result=vn)

def apply(inputs: InputSchema) -> OutputSchema:
return apply_jit(inputs.model_dump())


def jacobian(
inputs: InputSchema,
jac_inputs: set[str],
jac_outputs: set[str],
):
return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs))

def abstract_eval(abstract_inputs):
"""Calculate output shape of apply from the shape of its inputs."""
return {
"result": ShapeDType(shape=abstract_inputs.v0.shape, dtype="float32"),
}

def jacobian_vector_product(
inputs: InputSchema,
jvp_inputs: set[str],
jvp_outputs: set[str],
tangent_vector: dict[str, Any],
):
return jvp_jit(
inputs.model_dump(),
tuple(jvp_inputs),
tuple(jvp_outputs),
tangent_vector,
)


def vector_jacobian_product(
inputs: InputSchema,
vjp_inputs: set[str],
vjp_outputs: set[str],
cotangent_vector,
cotangent_vector: dict[str, Any],
):
signature = [
"v0",
"density",
"viscosity",
"inner_steps",
"outer_steps",
"max_velocity",
"cfl_safety_factor",
"domain_size_x",
"domain_size_y",
]
# We need to do this, rather than just use jvp inputs, as the order in jvp_inputs
# is not necessarily the same as the ordering of the args in the function signature.
static_args = [arg for arg in signature if arg not in vjp_inputs]
nonstatic_args = [arg for arg in signature if arg in vjp_inputs]

def cfd_fwd_reordered(*args, **kwargs):
return cfd_fwd(
**{**{arg: args[i] for i, arg in enumerate(nonstatic_args)}, **kwargs}
)
return vjp_jit(
inputs.model_dump(),
tuple(vjp_inputs),
tuple(vjp_outputs),
cotangent_vector,
)

out = {}
if "result" in vjp_outputs:
# Make the function depend only on nonstatic args, as jax.jvp
# differentiates w.r.t. all free arguments.
func = partial(
cfd_fwd_reordered, **{arg: getattr(inputs, arg) for arg in static_args}
)

_, vjp_func = jax.vjp(
func, *tuple(inputs.model_dump(include=vjp_inputs).values())
)
def abstract_eval(abstract_inputs):
"""Calculate output shape of apply from the shape of its inputs."""
is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)

jaxified_inputs = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x,
abstract_inputs.model_dump(),
is_leaf=is_shapedtype_dict,
)
dynamic_inputs, static_inputs = eqx.partition(
jaxified_inputs, filter_spec=is_shapedtype_struct
)

def wrapped_apply(dynamic_inputs):
inputs = eqx.combine(static_inputs, dynamic_inputs)
return apply_jit(inputs)

jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
return jax.tree.map(
lambda x: (
{"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x
),
jax_shapes,
is_leaf=is_shapedtype_struct,
)


@eqx.filter_jit
def jac_jit(
inputs: dict,
jac_inputs: tuple[str],
jac_outputs: tuple[str],
):
filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
return jax.jacrev(filtered_apply)(
flatten_with_paths(inputs, include_paths=jac_inputs)
)

vals = vjp_func(cotangent_vector["result"])
for arg, val in zip(nonstatic_args, vals, strict=False):
out[arg] = out.get(arg, 0.0) + val

return out
@eqx.filter_jit
def jvp_jit(
inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict
):
filtered_apply = filter_func(apply_jit, inputs, jvp_outputs)
return jax.jvp(
filtered_apply,
[flatten_with_paths(inputs, include_paths=jvp_inputs)],
[tangent_vector],
)


@eqx.filter_jit
def vjp_jit(
inputs: dict,
vjp_inputs: tuple[str],
vjp_outputs: tuple[str],
cotangent_vector: dict,
):
filtered_apply = filter_func(apply_jit, inputs, vjp_outputs)
_, vjp_func = jax.vjp(
filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
)
return vjp_func(cotangent_vector)[0]
3 changes: 2 additions & 1 deletion demo/cfd/cfd-tesseract/tesseract_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy==1.26.4
jax-cfd==0.2.1
jax[cpu]==0.4.34
jax[cpu]==0.6.0
equinox
Loading