Skip to content

Commit 8041a81

Browse files
committed
utils
1 parent dd1b782 commit 8041a81

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3280
-500
lines changed

examples/ansys/Readme.md

Lines changed: 3 additions & 4 deletions

examples/ansys/bars_mesh.vtk

-1.29 KB
Binary file not shown.

examples/ansys/grid_mesh.vtk

0 Bytes
Binary file not shown.

examples/ansys/optim_bars.ipynb

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,10 +2262,22 @@
22622262
},
22632263
{
22642264
"cell_type": "code",
2265-
"execution_count": 380,
2265+
"execution_count": 1,
22662266
"id": "a2a2a6b1",
22672267
"metadata": {},
2268-
"outputs": [],
2268+
"outputs": [
2269+
{
2270+
"ename": "NameError",
2271+
"evalue": "name 'n_steps' is not defined",
2272+
"output_type": "error",
2273+
"traceback": [
2274+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
2275+
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
2276+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[43mn_steps\u001b[49m):\n\u001b[32m 2\u001b[39m mesh = aux_hist[i][\u001b[33m\"\u001b[39m\u001b[33mhex_mesh\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 3\u001b[39m rho_dot = grad_storage[i + \u001b[32m3000\u001b[39m][\u001b[32m1\u001b[39m][: \u001b[38;5;28mlen\u001b[39m(mesh[\u001b[33m\"\u001b[39m\u001b[33mfaces\u001b[39m\u001b[33m\"\u001b[39m])][:, \u001b[32m0\u001b[39m]\n",
2277+
"\u001b[31mNameError\u001b[39m: name 'n_steps' is not defined"
2278+
]
2279+
}
2280+
],
22692281
"source": [
22702282
"for i in range(n_steps):\n",
22712283
" mesh = aux_hist[i][\"hex_mesh\"]\n",
@@ -2278,10 +2290,22 @@
22782290
},
22792291
{
22802292
"cell_type": "code",
2281-
"execution_count": 381,
2293+
"execution_count": 2,
22822294
"id": "bd470372",
22832295
"metadata": {},
2284-
"outputs": [],
2296+
"outputs": [
2297+
{
2298+
"ename": "NameError",
2299+
"evalue": "name 'params_hist' is not defined",
2300+
"output_type": "error",
2301+
"traceback": [
2302+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
2303+
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
2304+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m animation\n\u001b[32m 3\u001b[39m \u001b[38;5;66;03m# repeat the last frame a few times to show the final result\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m params_hist = \u001b[43mparams_hist\u001b[49m + [params] * \u001b[32m20\u001b[39m\n\u001b[32m 6\u001b[39m fig = plt.figure(figsize=(\u001b[32m7\u001b[39m, \u001b[32m4\u001b[39m))\n\u001b[32m 8\u001b[39m design_inputs[\u001b[33m\"\u001b[39m\u001b[33mprecompute_jacobian\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28;01mFalse\u001b[39;00m\n",
2305+
"\u001b[31mNameError\u001b[39m: name 'params_hist' is not defined"
2306+
]
2307+
}
2308+
],
22852309
"source": [
22862310
"from matplotlib import animation\n",
22872311
"\n",

examples/ansys/optim_bars_pymapdl.ipynb

Lines changed: 2130 additions & 0 deletions
Large diffs are not rendered by default.

examples/ansys/optim_grid.ipynb

Lines changed: 957 additions & 490 deletions
Large diffs are not rendered by default.

examples/ansys/out

Whitespace-only changes.

examples/ansys/rho_optim_sum_2.gif

117 KB

examples/ansys/sdf_fd_tess/tesseract_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,7 @@ def vector_jacobian_product(
519519
assert vjp_outputs == {"sdf"}
520520

521521
jac = jacobian(inputs, vjp_inputs, vjp_outputs)["sdf"]["differentiable_parameters"]
522-
# Reduce the cotangent vector to the shape of the Jacobian, to compute VJP by hand
523-
vjp = np.einsum("klmn,lmn->k", jac, cotangent_vector["sdf"])
522+
vjp = np.einsum("lmn,klmn->k", cotangent_vector["sdf"], jac)
524523

525524
return {"differentiable_parameters": vjp}
526525

examples/ansys/utils.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from collections.abc import Sequence
2+
from typing import TypeVar
3+
4+
import jax
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import pyvista as pv
8+
from mpl_toolkits.axes_grid1 import make_axes_locatable
9+
10+
11+
def plot_mesh(
12+
mesh: dict, bounds: Sequence[float], save_path: str | None = None
13+
) -> None:
14+
"""Plot a 3D triangular mesh with boundary conditions visualization.
15+
16+
Args:
17+
mesh: Dictionary containing 'points' and 'faces' arrays.
18+
save_path: Optional path to save the plot as an image file.
19+
bounds: bounds of the 3D space.
20+
"""
21+
Lx = bounds[0]
22+
Ly = bounds[1]
23+
Lz = bounds[2]
24+
25+
fig = plt.figure(figsize=(10, 8))
26+
ax = fig.add_subplot(111, projection="3d")
27+
ax.plot_trisurf(
28+
mesh["points"][:, 0],
29+
mesh["points"][:, 1],
30+
mesh["points"][:, 2],
31+
triangles=mesh["faces"],
32+
alpha=0.7,
33+
antialiased=True,
34+
color="lightblue",
35+
edgecolor="black",
36+
)
37+
38+
ax.set_xlim(-Lx / 2, Lx / 2)
39+
ax.set_ylim(-Ly / 2, Ly / 2)
40+
ax.set_zlim(-Lz / 2, Lz / 2)
41+
42+
# set equal aspect ratio
43+
ax.set_box_aspect(
44+
(
45+
(Lx) / (Ly),
46+
1,
47+
(Lz) / (Ly),
48+
)
49+
)
50+
51+
# x axis label
52+
ax.set_xlabel("X")
53+
ax.set_ylabel("Y")
54+
ax.set_zlabel("Z")
55+
56+
if save_path:
57+
# avoid showing the plot in notebook
58+
plt.savefig(save_path)
59+
plt.close(fig)
60+
61+
62+
def plot_grid_slice(field_slice, extent, ax, title, xlabel, ylabel):
63+
im = ax.imshow(field_slice.T, extent=extent, origin="lower")
64+
ax.set_title(title)
65+
ax.set_xlabel(xlabel)
66+
ax.set_ylabel(ylabel)
67+
# add colorbar
68+
divider = make_axes_locatable(ax)
69+
cax = divider.append_axes("right", size="5%", pad=0.1)
70+
plt.colorbar(im, cax=cax, orientation="vertical")
71+
return im
72+
73+
74+
def plot_grid(field, Lx, Ly, Lz, Nx, Ny, Nz, title="SDF"):
75+
_, axs = plt.subplots(1, 3, figsize=(15, 5))
76+
77+
plot_grid_slice(
78+
field[Nx // 2, :, :],
79+
extent=(-Ly / 2, Ly / 2, -Lz / 2, Lz / 2),
80+
ax=axs[0],
81+
title=f"{title} slice at x=0",
82+
xlabel="y",
83+
ylabel="z",
84+
)
85+
plot_grid_slice(
86+
field[:, Ny // 2, :],
87+
extent=(-Lx / 2, Lx / 2, -Lz / 2, Lz / 2),
88+
ax=axs[1],
89+
title=f"{title} slice at y=0",
90+
xlabel="x",
91+
ylabel="z",
92+
)
93+
plot_grid_slice(
94+
field[:, :, Nz // 2],
95+
extent=(-Lx / 2, Lx / 2, -Ly / 2, Ly / 2),
96+
ax=axs[2],
97+
title=f"{title} slice at z=0",
98+
xlabel="x",
99+
ylabel="y",
100+
)
101+
102+
103+
T = TypeVar("T")
104+
105+
106+
def stop_grads_int(x: T) -> T:
107+
"""Stops gradient computation.
108+
109+
We cannot use jax.lax.stop_gradient directly because Tesseract meshes are
110+
nested dictionaries with arrays and integers, and jax.lax.stop_gradient
111+
does not support integers.
112+
113+
Args:
114+
x: Input value.
115+
116+
Returns:
117+
Value with stopped gradients.
118+
"""
119+
120+
def stop(x):
121+
return jax._src.ad_util.stop_gradient_p.bind(x)
122+
123+
return jax.tree_util.tree_map(stop, x)
124+
125+
126+
def hex_to_pyvista(
127+
pts: jax.typing.ArrayLike, faces: jax.typing.ArrayLike, cell_data: dict
128+
) -> pv.UnstructuredGrid:
129+
"""Convert hex mesh defined by points and faces into a PyVista UnstructuredGrid.
130+
131+
Args:
132+
pts: Array of point coordinates, shape (N, 3).
133+
faces: Array of hexahedral cell connectivity, shape (M, 8).
134+
cell_data: additional cell center data.
135+
136+
Returns:
137+
PyVista mesh representing the hexahedral grid.
138+
"""
139+
pts = np.array(pts)
140+
faces = np.array(faces)
141+
142+
# Define the cell type for hexahedrons (VTK_HEXAHEDRON = 12)
143+
cell_type = pv.CellType.HEXAHEDRON
144+
cell_types = np.array([cell_type] * faces.shape[0], dtype=np.uint8)
145+
146+
# Prepare the cells array: [number_of_points, i0, i1, i2, i3, i4, i5, i6, i7]
147+
n_cells = faces.shape[0]
148+
cells = np.empty((n_cells, 9), dtype=np.int64)
149+
cells[:, 0] = 8 # Each cell has 8 points
150+
cells[:, 1:9] = faces
151+
152+
# Flatten the cells array for PyVista
153+
cells = cells.flatten()
154+
155+
mesh = pv.UnstructuredGrid(cells, cell_types, pts)
156+
157+
# Add cell data
158+
for name, data in cell_data.items():
159+
mesh.cell_data[name] = data
160+
161+
return mesh

0 commit comments

Comments
 (0)