Skip to content

Commit 7150049

Browse files
authored
doc: Add shape optimization demo (#45)
#### Relevant issue or PR n/a #### Description of changes This pulls in the shape optimization demo from our SciPy proceedings paper and enables testing on CI + rendering in the docs. #### Testing done manual + CI
1 parent 7a6db32 commit 7150049

File tree

17 files changed

+28248
-9
lines changed

17 files changed

+28248
-9
lines changed

.github/workflows/test_examples.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ on:
55
branches:
66
- main
77

8+
pull_request:
9+
paths:
10+
- examples/**
11+
812
jobs:
913
tests:
1014
strategy:
@@ -19,6 +23,7 @@ jobs:
1923
example:
2024
- simple
2125
- cfd
26+
- fem-shapeopt
2227

2328
fail-fast: false
2429

@@ -41,6 +46,11 @@ jobs:
4146
- name: Restore UV environment
4247
run: cp production.uv.lock uv.lock
4348

49+
- name: Install system requirements
50+
run: |
51+
sudo apt-get update
52+
sudo apt-get install -y libosmesa6
53+
4454
- name: Install dev requirements
4555
run: |
4656
uv sync --extra dev --frozen

docs/conf.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from tesseract_jax import __version__
1414

15+
here = Path(__file__).parent.resolve()
16+
1517
project = "Tesseract-JAX"
1618
copyright = "2025, Pasteur Labs"
1719
author = "The Tesseract-JAX Team @ Pasteur Labs + OSS contributors"
@@ -78,8 +80,9 @@
7880
# Do not execute notebooks during build (just take existing output)
7981
nb_execution_mode = "off"
8082

81-
# Copy example notebooks to demo_notebooks folder on every build
82-
for example_notebook in Path("../examples").glob("*/demo.ipynb"):
83-
# Copy the example notebook to the docs folder
84-
dest = (Path("demo_notebooks") / example_notebook.parent.name).with_suffix(".ipynb")
85-
shutil.copyfile(example_notebook, dest)
83+
# Copy example notebooks to docs/examples folder on every build
84+
for example_dir in Path("../examples").glob("*/"):
85+
# Copy the example directory to the docs folder
86+
shutil.copytree(
87+
example_dir, here / "examples" / example_dir.name, dirs_exist_ok=True
88+
)

docs/demo_notebooks/.gitignore

Lines changed: 0 additions & 1 deletion
This file was deleted.

docs/examples/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
**/*
2+
!.gitignore

docs/index.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def vector_sum(x, y):
1313
jax.grad(vector_sum)(x, y) # 🎉
1414
```
1515

16-
Want to learn more? See how to [get started](content/get-started.md) with Tesseract-JAX, explore the [API reference](content/api.md), or learn by [example](demo_notebooks/simple.ipynb).
16+
Want to learn more? See how to [get started](content/get-started.md) with Tesseract-JAX, explore the [API reference](content/api.md), or learn by [example](examples/simple/demo.ipynb).
1717

1818
## License
1919

@@ -36,8 +36,9 @@ content/api
3636
:maxdepth: 2
3737
:hidden:
3838
39-
demo_notebooks/simple.ipynb
40-
demo_notebooks/cfd.ipynb
39+
examples/simple/demo.ipynb
40+
examples/cfd/demo.ipynb
41+
examples/fem-shapeopt/demo.ipynb
4142
```
4243

4344
```{toctree}

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ This directory contains example Tesseract configurations, notebooks. and scripts
66

77
- [Simple](simple/demo.ipynb): A basic example of using Tesseract-JAX with a simple vector addition task. It demonstrates how to build a Tesseract and execute it within JAX.
88
- [CFD](cfd/demo.ipynb): A more complex example demonstrating how to use Tesseract-JAX to differentiate through a computational fluid dynamics (CFD) simulation in an optimization context.
9+
- [FEM Shape Optimization](fem-shapeopt/demo.ipynb): A step-by-step guide to implementing a parametric shape optimization pipeline using Tesseract-JAX, involving multiple Tesseracts working together to optimize a design based on finite element method (FEM) simulations.

examples/fem-shapeopt/demo.ipynb

Lines changed: 27664 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import pyvista as pv
5+
from pydantic import BaseModel, Field
6+
from tesseract_core.runtime import Array, Differentiable, Float32, ShapeDType
7+
8+
#
9+
# Schemata
10+
#
11+
12+
13+
class InputSchema(BaseModel):
14+
bar_params: Differentiable[
15+
Array[
16+
(None, None, 3),
17+
Float32,
18+
]
19+
] = Field(
20+
description=(
21+
"Vertex positions of the bar geometry. "
22+
"The shape is (num_bars, num_vertices, 3), where num_bars is the number of bars "
23+
"and num_vertices is the number of vertices per bar. The last dimension represents "
24+
"the x, y, z coordinates of each vertex."
25+
)
26+
)
27+
28+
bar_radius: float = Field(
29+
default=1.5,
30+
description=(
31+
"Radius of the bars in the geometry. "
32+
"This is a scalar value that defines the thickness of the bars."
33+
),
34+
)
35+
36+
Lx: float = Field(
37+
default=60.0,
38+
description=(
39+
"Length of the plane in the x direction. "
40+
"This is a scalar value that defines the size of the plane along the x-axis."
41+
),
42+
)
43+
Ly: float = Field(
44+
default=30.0,
45+
description=(
46+
"Length of the plane in the y direction. "
47+
"This is a scalar value that defines the size of the plane along the y-axis."
48+
),
49+
)
50+
Nx: int = Field(
51+
default=60,
52+
description=(
53+
"Number of points in the x direction. "
54+
"This is an integer value that defines the resolution of the plane along the x-axis."
55+
),
56+
)
57+
Ny: int = Field(
58+
default=30,
59+
description=(
60+
"Number of points in the y direction. "
61+
"This is an integer value that defines the resolution of the plane along the y-axis."
62+
),
63+
)
64+
epsilon: float = Field(
65+
default=1e-5,
66+
description=(
67+
"Epsilon value for finite difference approximation of the Jacobian. "
68+
"This is a small scalar value used to compute the numerical gradient."
69+
),
70+
)
71+
72+
73+
class OutputSchema(BaseModel):
74+
sdf: Differentiable[
75+
Array[
76+
(
77+
None,
78+
None,
79+
),
80+
Float32,
81+
]
82+
] = Field(description="SDF field of the geometry")
83+
84+
85+
#
86+
# Helper functions
87+
#
88+
89+
90+
def build_geometry(
91+
params: np.ndarray,
92+
radius: float,
93+
) -> list[pv.PolyData]:
94+
"""Build a pyvista geometry from the parameters.
95+
96+
The parameters are expected to be of shape (n_chains, n_edges_per_chain + 1, 3),
97+
"""
98+
n_chains = params.shape[0]
99+
geometry = []
100+
101+
for chain in range(n_chains):
102+
tube = pv.Spline(points=params[chain]).tube(radius=radius, capping=False)
103+
geometry.append(tube)
104+
105+
return geometry
106+
107+
108+
def compute_sdf(
109+
params: np.ndarray,
110+
radius: float,
111+
Lx: float,
112+
Ly: float,
113+
Nx: int,
114+
Ny: int,
115+
) -> pv.PolyData:
116+
"""Create a pyvista plane that has the SDF values stored as a vertex attribute.
117+
118+
The SDF field is computed based on the geometry defined by the parameters.
119+
"""
120+
grid_coords = pv.Plane(
121+
center=(0, 0, 0),
122+
direction=(0, 0, 1),
123+
i_size=Lx,
124+
j_size=Ly,
125+
i_resolution=Nx - 1,
126+
j_resolution=Ny - 1,
127+
)
128+
grid_coords = grid_coords.triangulate()
129+
130+
geometries = build_geometry(
131+
params,
132+
radius=radius,
133+
)
134+
135+
sdf_field = None
136+
137+
for geometry in geometries:
138+
# Compute the implicit distance from the geometry to the grid coordinates.
139+
# The implicit distance is a signed distance field, where positive values
140+
# are outside the geometry and negative values are inside.
141+
this_sdf = grid_coords.compute_implicit_distance(geometry.triangulate())
142+
if sdf_field is None:
143+
sdf_field = this_sdf
144+
else:
145+
sdf_field["implicit_distance"] = np.minimum(
146+
sdf_field["implicit_distance"], this_sdf["implicit_distance"]
147+
)
148+
149+
return sdf_field
150+
151+
152+
def apply_fn(
153+
params: np.ndarray,
154+
radius: float,
155+
Lx: float,
156+
Ly: float,
157+
Nx: int,
158+
Ny: int,
159+
) -> np.ndarray:
160+
"""Get the sdf values of a the geometry defined by the parameters as a 2D array."""
161+
sdf_geom = compute_sdf(
162+
params,
163+
radius=radius,
164+
Lx=Lx,
165+
Ly=Ly,
166+
Nx=Nx,
167+
Ny=Ny,
168+
)["implicit_distance"]
169+
170+
# The implicit distance is a 1D where the indexing is tranposed.
171+
# We need to reshape it to a 2D array with the shape (Ny, Nx) and then transpose it to get the correct orientation.
172+
return sdf_geom.reshape((Ny, Nx)).T
173+
174+
175+
def jac_sdf_wrt_params(
176+
params: np.ndarray,
177+
radius: float,
178+
Lx: float,
179+
Ly: float,
180+
Nx: int,
181+
Ny: int,
182+
epsilon: float,
183+
) -> np.ndarray:
184+
"""Compute the Jacobian of the SDF values with respect to the parameters.
185+
186+
The Jacobian is computed by finite differences.
187+
The shape of the Jacobian is (n_chains, n_edges_per_chain + 1, 3, Nx, Ny).
188+
"""
189+
n_chains = params.shape[0]
190+
n_edges_per_chain = params.shape[1] - 1
191+
192+
jac = np.zeros(
193+
(
194+
n_chains,
195+
n_edges_per_chain + 1,
196+
3, # number of dimensions (x, y, z)
197+
Nx,
198+
Ny,
199+
)
200+
)
201+
202+
sdf_base = apply_fn(
203+
params,
204+
radius=radius,
205+
Lx=Lx,
206+
Ly=Ly,
207+
Nx=Nx,
208+
Ny=Ny,
209+
)
210+
211+
for chain in range(n_chains):
212+
for vertex in range(0, n_edges_per_chain + 1):
213+
# we only care about the y coordinate
214+
i = 1
215+
params_eps = params.copy()
216+
params_eps[chain, vertex, i] += epsilon
217+
218+
sdf_epsilon = apply_fn(
219+
params_eps,
220+
radius=radius,
221+
Lx=Lx,
222+
Ly=Ly,
223+
Nx=Nx,
224+
Ny=Ny,
225+
)
226+
jac[chain, vertex, i] = (sdf_epsilon - sdf_base) / epsilon
227+
228+
return jac
229+
230+
231+
#
232+
# Tesseract endpoints
233+
#
234+
235+
236+
def apply(inputs: InputSchema) -> OutputSchema:
237+
return OutputSchema(
238+
sdf=apply_fn(
239+
inputs.bar_params,
240+
radius=inputs.bar_radius,
241+
Lx=inputs.Lx,
242+
Ly=inputs.Ly,
243+
Nx=inputs.Nx,
244+
Ny=inputs.Ny,
245+
)
246+
)
247+
248+
249+
def vector_jacobian_product(
250+
inputs: InputSchema,
251+
vjp_inputs: set[str],
252+
vjp_outputs: set[str],
253+
cotangent_vector: dict[str, Any],
254+
):
255+
assert vjp_inputs == {"bar_params"}
256+
assert vjp_outputs == {"sdf"}
257+
258+
jac = jac_sdf_wrt_params(
259+
inputs.bar_params,
260+
radius=inputs.bar_radius,
261+
Lx=inputs.Lx,
262+
Ly=inputs.Ly,
263+
Nx=inputs.Nx,
264+
Ny=inputs.Ny,
265+
epsilon=inputs.epsilon,
266+
)
267+
# Reduce the cotangent vector to the shape of the Jacobian, to compute VJP by hand
268+
vjp = np.einsum("ijklm,lm->ijk", jac, cotangent_vector["sdf"]).astype(np.float32)
269+
return {"bar_params": vjp}
270+
271+
272+
def abstract_eval(abstract_inputs):
273+
"""Calculate output shape of apply from the shape of its inputs."""
274+
return {
275+
"sdf": ShapeDType(
276+
shape=(abstract_inputs.Nx, abstract_inputs.Ny), dtype="float32"
277+
)
278+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: design-tube-sdf
2+
version: "0.1.0"
3+
description: |
4+
Tesseract that generates a gridded signed distance function (SDF) for a set of shape parameters.
5+
6+
Parameters are expected to define the control points and radii of piecewise linear tubes in 3D space.
7+
8+
Has a VJP endpoint defined that uses finite differences under the hood.
9+
10+
build_config:
11+
target_platform: "linux/x86_64"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
numpy==1.26.4
2+
pyvista==0.45.2

0 commit comments

Comments
 (0)