Skip to content

Commit 93d6a4e

Browse files
committed
increase mesher speed
1 parent 759cc6d commit 93d6a4e

File tree

4 files changed

+348
-55
lines changed

4 files changed

+348
-55
lines changed

examples/ansys/demo_2.ipynb

Lines changed: 317 additions & 21 deletions
Large diffs are not rendered by default.
-23.4 KB
Binary file not shown.

examples/ansys/hot_design_tess/tesseract_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ def vector_jacobian_product(
318318
# Reduce the cotangent vector to the shape of the Jacobian, to compute VJP by hand
319319
vjp = np.einsum("klmn,lmn->k", jac, cotangent_vector["sdf"]).astype(np.float32)
320320

321-
print(vjp.shape)
322321
return {"differentiable_parameters": vjp}
323322

324323

examples/ansys/meshing_tess/tesseract_api.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def vectorized_subdivide_hex_mesh(
131131
This method introduces duplicates of points that should later be merged.
132132
133133
Hexahedron is constructed as follows:
134+
134135
3 -------- 2
135136
/| /|
136137
7 -------- 6 |
@@ -140,23 +141,26 @@ def vectorized_subdivide_hex_mesh(
140141
4 -------- 5
141142
142143
Axis orientation:
143-
y
144-
|
145-
|____ x
146-
/
147-
/
148-
z
144+
145+
y
146+
|
147+
|____ x
148+
/
149+
/
150+
z
149151
150152
"""
151-
n_hex = hex_cells.shape[0]
152-
n_new_pts = (8 * 8) * n_hex # 8 corners per new hex, 8 new hexes per old hex
153+
n_hex_new = mask.sum()
154+
n_new_pts = (8 * 8) * n_hex_new # 8 corners per new hex, 8 new hexes per old hex
153155

154156
new_pts_coords = jnp.zeros((n_new_pts, 3), dtype=pts_coords.dtype)
155-
new_hex_cells = jnp.zeros((n_hex * 8, 8), dtype=hex_cells.dtype)
157+
new_hex_cells = jnp.zeros((n_hex_new * 8, 8), dtype=hex_cells.dtype)
156158

157-
voxel_sizes = jnp.abs(pts_coords[hex_cells[:, 6]] - pts_coords[hex_cells[:, 0]])
159+
voxel_sizes = jnp.abs(
160+
pts_coords[hex_cells[mask, 6]] - pts_coords[hex_cells[mask, 0]]
161+
)
158162

159-
center_points = jnp.mean(pts_coords[hex_cells], axis=1) # (n_hex, 3)
163+
center_points = jnp.mean(pts_coords[hex_cells[mask]], axis=1) # (n_hex, 3)
160164
offsets = jnp.array(
161165
[
162166
[-0.25, -0.25, -0.25],
@@ -169,20 +173,20 @@ def vectorized_subdivide_hex_mesh(
169173
[-0.25, 0.25, 0.25],
170174
]
171175
).reshape((1, 8, 3)).repeat(voxel_sizes.shape[0], axis=0) * voxel_sizes.reshape(
172-
(n_hex, 1, 3)
176+
(n_hex_new, 1, 3)
173177
).repeat(8, axis=1)
174178

175179
for cell in range(8):
176180
center = center_points + offsets[:, cell]
177181

178182
for corner in range(8):
179183
new_pts_coords = new_pts_coords.at[
180-
jnp.arange(n_hex) * 64 + cell * 8 + corner
184+
jnp.arange(n_hex_new) * 64 + cell * 8 + corner
181185
].set(center - offsets[:, corner])
182186

183-
new_hex_cells = new_hex_cells.at[jnp.arange(n_hex) * 8 + cell, corner].set(
184-
jnp.arange(n_hex) * 64 + cell * 8 + corner
185-
)
187+
new_hex_cells = new_hex_cells.at[
188+
jnp.arange(n_hex_new) * 8 + cell, corner
189+
].set(jnp.arange(n_hex_new) * 64 + cell * 8 + corner)
186190

187191
def reindex_and_mask(
188192
coords: jnp.ndarray, cells: jnp.ndarray, keep_mask: jnp.ndarray
@@ -202,9 +206,9 @@ def reindex_and_mask(
202206

203207
return coords, cells
204208

205-
new_pts_coords, new_hex_cells = reindex_and_mask(
206-
new_pts_coords, new_hex_cells, mask.repeat(8)
207-
)
209+
# new_pts_coords, new_hex_cells = reindex_and_mask(
210+
# new_pts_coords, new_hex_cells, mask.repeat(8)
211+
# )
208212
old_pts_coords, old_hex_cells = reindex_and_mask(
209213
pts_coords, hex_cells, jnp.logical_not(mask)
210214
)
@@ -279,7 +283,6 @@ def recursive_subdivide_hex_mesh(
279283

280284
return pts_coords, hex_cells
281285

282-
mesh = None # cache for the last generated mesh
283286

284287
# @lru_cache(maxsize=1)
285288
def generate_mesh(
@@ -314,8 +317,6 @@ def generate_mesh(
314317
Lz=Lz,
315318
)
316319

317-
mesh = (pts, cells) # cache the generated mesh
318-
319320
return pts, cells
320321

321322

@@ -397,17 +398,14 @@ def vector_jacobian_product(
397398
assert vjp_inputs == {"field_values"}
398399
assert vjp_outputs == {"mesh_cell_values"}
399400

400-
if mesh is None:
401-
pts, cells = generate_mesh(
402-
Lx=inputs.Lx,
403-
Ly=inputs.Ly,
404-
Lz=inputs.Lz,
405-
sizing_field=inputs.sizing_field,
406-
max_levels=inputs.max_subdivision_levels,
407-
)
408-
else:
409-
print("Using cached mesh for VJP computation.")
410-
pts, cells = mesh
401+
pts, cells = generate_mesh(
402+
Lx=inputs.Lx,
403+
Ly=inputs.Ly,
404+
Lz=inputs.Lz,
405+
sizing_field=inputs.sizing_field,
406+
max_levels=inputs.max_subdivision_levels,
407+
)
408+
411409
cell_centers = jnp.mean(pts[cells], axis=1)
412410

413411
xs = jnp.linspace(-inputs.Lx / 2, inputs.Lx / 2, inputs.field_values.shape[0])

0 commit comments

Comments
 (0)