Skip to content

Commit 10a9a86

Browse files
committed
optim working
1 parent e388269 commit 10a9a86

File tree

8 files changed

+2865
-140
lines changed

8 files changed

+2865
-140
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,5 @@ dmypy.json
141141
# Ignore UV dev environments
142142
uv.lock
143143
uv.lock.bak
144+
145+
tmp_img/

examples/ansys/demo_2.ipynb

Lines changed: 2792 additions & 105 deletions
Large diffs are not rendered by default.
-430 KB
Binary file not shown.

examples/ansys/fem_tess/tesseract_api.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,20 @@ def get_tensor_map(self) -> Callable:
106106
"""
107107

108108
def stress(u_grad, theta):
109-
E = 70e3
109+
Emax = 70.0e3
110+
Emin = 1e-3 * Emax
111+
penal = 3.0
112+
113+
E = Emin + (Emax - Emin) * theta[0] ** penal
114+
110115
nu = 0.3
111116
mu = E / (2.0 * (1.0 + nu))
112117
lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))
113118

114119
epsilon = 0.5 * (u_grad + u_grad.T)
115120
sigma = lmbda * jnp.trace(epsilon) * jnp.eye(self.dim) + 2 * mu * epsilon
121+
122+
sigma = lmbda * jnp.trace(epsilon) * jnp.eye(self.dim) + 2.0 * mu * epsilon
116123
return sigma
117124

118125
return stress
@@ -123,13 +130,7 @@ def get_surface_maps(self) -> list[Callable]:
123130
Returns:
124131
List of van Neumann boundary condition value functions.
125132
"""
126-
127-
def surface_map(u, x):
128-
return jnp.array([0.0, 0.0, 100.0])
129-
130-
return [surface_map]
131-
132-
# return self.van_neumann_value_fns
133+
return self.van_neumann_value_fns
133134

134135
def set_params(self, params: jnp.ndarray) -> None:
135136
"""Set density parameters for topology optimization.
@@ -280,6 +281,41 @@ def apply_fn(inputs: dict) -> dict:
280281
Returns:
281282
Dictionary containing the compliance of the structure.
282283
"""
284+
from typing import TypeVar
285+
286+
T = TypeVar("T")
287+
288+
def stop_grads_int(x: T) -> T:
289+
"""Stops gradient computation.
290+
291+
We cannot use jax.lax.stop_gradient directly because Tesseract meshes are
292+
nested dictionaries with arrays and integers, and jax.lax.stop_gradient
293+
does not support integers.
294+
295+
Args:
296+
x: Input value.
297+
298+
Returns:
299+
Value with stopped gradients.
300+
"""
301+
302+
def stop(x):
303+
return jax._src.ad_util.stop_gradient_p.bind(x)
304+
305+
return jax.tree_util.tree_map(stop, x)
306+
307+
# stop grads on all inputs except rho
308+
309+
# problem, fwd_pred = setup(
310+
# pts=stop_grads_int(inputs["hex_mesh"]["points"][: inputs["hex_mesh"]["n_points"]]),
311+
# cells=stop_grads_int(inputs["hex_mesh"]["faces"][: inputs["hex_mesh"]["n_faces"]]),
312+
# dirichlet_mask=stop_grads_int(inputs["dirichlet_mask"]),
313+
# dirichlet_values=stop_grads_int(inputs["dirichlet_values"]),
314+
# van_neumann_mask=stop_grads_int(inputs["van_neumann_mask"]),
315+
# van_neumann_values=stop_grads_int(inputs["van_neumann_values"]),
316+
# )
317+
318+
# no stop grads
283319
problem, fwd_pred = setup(
284320
pts=inputs["hex_mesh"]["points"][: inputs["hex_mesh"]["n_points"]],
285321
cells=inputs["hex_mesh"]["faces"][: inputs["hex_mesh"]["n_faces"]],
@@ -289,7 +325,7 @@ def apply_fn(inputs: dict) -> dict:
289325
van_neumann_values=inputs["van_neumann_values"],
290326
)
291327

292-
rho = inputs["rho"][: inputs["hex_mesh"]["n_points"]]
328+
rho = inputs["rho"][: inputs["hex_mesh"]["n_faces"]]
293329

294330
sol_list = fwd_pred(rho)
295331
compliance = problem.compute_compliance(sol_list[0])

examples/ansys/gf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
images = []
44

5-
for i in range(60):
6-
filename = f"img/mesh_optim_{i:03d}.png"
5+
for i in range(30):
6+
filename = f"tmp_img/mesh_optim_{i:03d}.png"
77
images.append(imageio.imread(filename))
88
print(f"Added {filename} to gif.")
99
imageio.mimsave("mesh_optim.gif", images, fps=10)

examples/ansys/hot_design_tess/tesseract_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def vector_jacobian_product(
313313
epsilon=inputs.epsilon,
314314
)
315315
if inputs.normalize_jacobian:
316-
n_elements = inputs.Nx * inputs.Ny * inputs.Nz
316+
n_elements = (
317+
inputs.grid_elements[0] * inputs.grid_elements[1] * inputs.grid_elements[2]
318+
)
317319
jac = jac / n_elements
318320
# Reduce the cotangent vector to the shape of the Jacobian, to compute VJP by hand
319321
vjp = np.einsum("klmn,lmn->k", jac, cotangent_vector["sdf"]).astype(np.float32)

examples/ansys/mesh_optim.gif

-1.17 MB
Loading

examples/ansys/meshing_tess/tesseract_api.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,8 @@ class InputSchema(BaseModel):
3030
"Sizing field values defined on a regular grid for mesh adaptation."
3131
)
3232
)
33-
Lx: float = Field(
34-
default=60.0,
35-
description=("Length of the domain in the x direction. "),
36-
)
37-
Ly: float = Field(
38-
default=30.0,
39-
description=("Length of the domain in the y direction. "),
40-
)
41-
Lz: float = Field(
42-
default=30.0,
43-
description=("Length of the domain in the z direction. "),
33+
domain_size: tuple[float, float, float] = Field(
34+
description=("Size of the domain in x, y, z directions.")
4435
)
4536

4637
max_points: int = Field(
@@ -326,10 +317,13 @@ def apply(inputs: InputSchema) -> OutputSchema:
326317
Returns:
327318
OutputSchema, outputs of the function.
328319
"""
320+
Lx = inputs.domain_size[0]
321+
Ly = inputs.domain_size[1]
322+
Lz = inputs.domain_size[2]
329323
pts, cells = generate_mesh(
330-
Lx=inputs.Lx,
331-
Ly=inputs.Ly,
332-
Lz=inputs.Lz,
324+
Lx=Lx,
325+
Ly=Ly,
326+
Lz=Lz,
333327
sizing_field=inputs.sizing_field,
334328
max_levels=inputs.max_subdivision_levels,
335329
)
@@ -339,9 +333,9 @@ def apply(inputs: InputSchema) -> OutputSchema:
339333
cells_padded = jnp.zeros((inputs.max_cells, 8), dtype=cells.dtype)
340334
cells_padded = cells_padded.at[: cells.shape[0], :].set(cells)
341335

342-
xs = jnp.linspace(-inputs.Lx / 2, inputs.Lx / 2, inputs.field_values.shape[0])
343-
ys = jnp.linspace(-inputs.Ly / 2, inputs.Ly / 2, inputs.field_values.shape[1])
344-
zs = jnp.linspace(-inputs.Lz / 2, inputs.Lz / 2, inputs.field_values.shape[2])
336+
xs = jnp.linspace(-Lx / 2, Lx / 2, inputs.field_values.shape[0])
337+
ys = jnp.linspace(-Ly / 2, Ly / 2, inputs.field_values.shape[1])
338+
zs = jnp.linspace(-Lz / 2, Lz / 2, inputs.field_values.shape[2])
345339

346340
interpolator = RegularGridInterpolator(
347341
(xs, ys, zs),
@@ -395,19 +389,23 @@ def vector_jacobian_product(
395389
assert vjp_inputs == {"field_values"}
396390
assert vjp_outputs == {"mesh_cell_values"}
397391

392+
Lx = inputs.domain_size[0]
393+
Ly = inputs.domain_size[1]
394+
Lz = inputs.domain_size[2]
395+
398396
pts, cells = generate_mesh(
399-
Lx=inputs.Lx,
400-
Ly=inputs.Ly,
401-
Lz=inputs.Lz,
397+
Lx=Lx,
398+
Ly=Ly,
399+
Lz=Lz,
402400
sizing_field=inputs.sizing_field,
403401
max_levels=inputs.max_subdivision_levels,
404402
)
405403

406404
cell_centers = jnp.mean(pts[cells], axis=1)
407405

408-
xs = jnp.linspace(-inputs.Lx / 2, inputs.Lx / 2, inputs.field_values.shape[0])
409-
ys = jnp.linspace(-inputs.Ly / 2, inputs.Ly / 2, inputs.field_values.shape[1])
410-
zs = jnp.linspace(-inputs.Lz / 2, inputs.Lz / 2, inputs.field_values.shape[2])
406+
xs = jnp.linspace(-Lx / 2, Lx / 2, inputs.field_values.shape[0])
407+
ys = jnp.linspace(-Ly / 2, Ly / 2, inputs.field_values.shape[1])
408+
zs = jnp.linspace(-Lz / 2, Lz / 2, inputs.field_values.shape[2])
411409
xs, ys, zs = jnp.meshgrid(xs, ys, zs, indexing="ij")
412410

413411
field_cotangent_vector = griddata(

0 commit comments

Comments
 (0)