Skip to content

Commit 20d3f5d

Browse files
committed
removed loops from hex mesher
1 parent e183976 commit 20d3f5d

File tree

7 files changed

+361
-4995
lines changed

7 files changed

+361
-4995
lines changed

examples/ansys/gf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
images = []
44

5-
for i in range(40):
5+
for i in range(20):
66
filename = f"tmp_img/mesh_optim_{i:03d}.png"
77
images.append(imageio.imread(filename))
88
print(f"Added {filename} to gif.")

examples/ansys/mesh_optim.gif

-715 KB
Loading

examples/ansys/meshing_tess/tesseract_api.py

Lines changed: 77 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -141,95 +141,95 @@ def vectorized_subdivide_hex_mesh(
141141
|/____ x
142142
143143
"""
144-
n_hex_subd = mask.sum()
145-
n_hex_each = (split_x + 1) * (split_y + 1) * (split_z + 1)
146-
n_new_pts = (
147-
8 * n_hex_each
148-
) * n_hex_subd # 8 corners per new hex, 8 new hexes per old hex
149-
n_new_cells = n_hex_each * n_hex_subd
144+
# compute sizes
145+
n_hex_to_subdiv = mask.sum()
146+
n_hex_each_hex = (split_x + 1) * (split_y + 1) * (split_z + 1)
147+
n_points_per_hex = 8
148+
# 8 corners per new hex, 8 new hexes per old hex
149+
n_new_pts = n_points_per_hex * n_hex_each_hex * n_hex_to_subdiv
150+
n_new_cells = n_hex_each_hex * n_hex_to_subdiv
150151

151152
new_pts_coords = jnp.zeros((n_new_pts, 3), dtype=pts_coords.dtype)
152153
new_hex_cells = jnp.zeros((n_new_cells, 8), dtype=hex_cells.dtype)
153154

154-
voxel_sizes = jnp.abs(
155-
pts_coords[hex_cells[mask, 6]] - pts_coords[hex_cells[mask, 0]]
156-
)
157-
158-
center_points = jnp.mean(pts_coords[hex_cells[mask]], axis=1) # (n_hex, 3)
155+
# get sizes of hexes to subdivide
156+
hex_sizes = jnp.abs(pts_coords[hex_cells[mask, 6]] - pts_coords[hex_cells[mask, 0]])
157+
# Ceneter points of shape (n_hex_to_subdiv, 3)
158+
center_points = jnp.mean(pts_coords[hex_cells[mask]], axis=1)
159159

160-
cell_offsets = jnp.zeros((1, n_hex_each, 3), dtype=jnp.float32)
160+
# Build cell offset tensor
161+
# that is the offset of a hex center to each of the new hex centers
162+
cell_offsets = jnp.zeros((1, n_hex_each_hex, 3), dtype=jnp.float32)
161163
index = 0
162164
for ix in range(split_x + 1):
163165
for iy in range(split_y + 1):
164166
for iz in range(split_z + 1):
165167
cell_offsets = cell_offsets.at[0, index].set(
166168
jnp.array(
167169
[
168-
(ix * 0.5 - 0.5) if split_x else 0.0,
169-
(iy * 0.5 - 0.5) if split_y else 0.0,
170-
(iz * 0.5 - 0.5) if split_z else 0.0,
170+
(0.25 - ix * 0.5) if split_x else 0.0,
171+
(0.25 - iy * 0.5) if split_y else 0.0,
172+
(0.25 - iz * 0.5) if split_z else 0.0,
171173
]
172174
).T
173175
)
174176
index += 1
175177

176-
cell_offsets = cell_offsets.repeat(
177-
voxel_sizes.shape[0], axis=0
178-
) * voxel_sizes.reshape((n_hex_subd, 1, 3)).repeat(n_hex_each, axis=1)
179-
180-
offsets = jnp.array(
178+
# We now repeat the cell offsets and scale them by the corresponding hex sizes
179+
# Hence we have a cell_offset tensor of shape (n_hex_to_subdiv, n_hex_each_hex, 3)
180+
cell_offsets = cell_offsets.repeat(n_hex_to_subdiv, axis=0) * hex_sizes.reshape(
181+
(n_hex_to_subdiv, 1, 3)
182+
).repeat(n_hex_each_hex, axis=1)
183+
184+
# Build point offset tensor
185+
# that is the offset of a hex center to each of the new hex points
186+
offset_x = 0.25 if split_x else 0.5
187+
offset_y = 0.25 if split_y else 0.5
188+
offset_z = 0.25 if split_z else 0.5
189+
point_offsets = jnp.array(
181190
[
182-
[-0.25, -0.25, -0.25],
183-
[0.25, -0.25, -0.25],
184-
[0.25, 0.25, -0.25],
185-
[-0.25, 0.25, -0.25],
186-
[-0.25, -0.25, 0.25],
187-
[0.25, -0.25, 0.25],
188-
[0.25, 0.25, 0.25],
189-
[-0.25, 0.25, 0.25],
191+
[-offset_x, -offset_y, -offset_z],
192+
[offset_x, -offset_y, -offset_z],
193+
[offset_x, offset_y, -offset_z],
194+
[-offset_x, offset_y, -offset_z],
195+
[-offset_x, -offset_y, offset_z],
196+
[offset_x, -offset_y, offset_z],
197+
[offset_x, offset_y, offset_z],
198+
[-offset_x, offset_y, offset_z],
190199
]
191-
).reshape((1, 8, 3)).repeat(voxel_sizes.shape[0], axis=0) * voxel_sizes.reshape(
192-
(n_hex_subd, 1, 3)
193-
).repeat(8, axis=1)
194-
195-
for cell in range(n_hex_each):
196-
center = center_points + cell_offsets[:, cell]
197-
198-
for corner in range(8):
199-
new_pts_coords = new_pts_coords.at[
200-
jnp.arange(n_hex_subd) * 8 * n_hex_each + cell * n_hex_each + corner
201-
].set(center + offsets[:, corner])
202-
203-
new_hex_cells = new_hex_cells.at[
204-
jnp.arange(n_hex_subd) * n_hex_each + cell, corner
205-
].set(jnp.arange(n_hex_subd) * 8 * n_hex_each + cell * n_hex_each + corner)
206-
207-
# offsets = jnp.array(
208-
# [
209-
# [-0.25, -0.25, -0.25],
210-
# [0.25, -0.25, -0.25],
211-
# [0.25, 0.25, -0.25],
212-
# [-0.25, 0.25, -0.25],
213-
# [-0.25, -0.25, 0.25],
214-
# [0.25, -0.25, 0.25],
215-
# [0.25, 0.25, 0.25],
216-
# [-0.25, 0.25, 0.25],
217-
# ]
218-
# ).reshape((1, 8, 3)).repeat(voxel_sizes.shape[0], axis=0) * voxel_sizes.reshape(
219-
# (n_hex_new, 1, 3)
220-
# ).repeat(8, axis=1)
221-
222-
# for cell in range(8):
223-
# center = center_points + offsets[:, cell]
224-
225-
# for corner in range(8):
226-
# new_pts_coords = new_pts_coords.at[
227-
# jnp.arange(n_hex_new) * 64 + cell * 8 + corner
228-
# ].set(center + offsets[:, corner])
229-
230-
# new_hex_cells = new_hex_cells.at[
231-
# jnp.arange(n_hex_new) * 8 + cell, corner
232-
# ].set(jnp.arange(n_hex_new) * 64 + cell * 8 + corner)
200+
)
201+
202+
# Repeat the point offsets and scale them by the corresponding hex sizes
203+
# -> point_offset tensor of shape (n_hex_to_subdiv, n_points_per_hex, 3)
204+
point_offsets = point_offsets.reshape((1, n_points_per_hex, 3)).repeat(
205+
hex_sizes.shape[0], axis=0
206+
) * hex_sizes.reshape((n_hex_to_subdiv, 1, 3)).repeat(n_points_per_hex, axis=1)
207+
208+
# Repeat the two offsets at an additional axis to get all combinations
209+
cell_offsets = cell_offsets.reshape((n_hex_to_subdiv, n_hex_each_hex, 1, 3)).repeat(
210+
n_points_per_hex, axis=2
211+
)
212+
point_offsets = point_offsets.reshape(
213+
(n_hex_to_subdiv, 1, n_points_per_hex, 3)
214+
).repeat(n_hex_each_hex, axis=1)
215+
216+
# Compute total offset relative to old hex center
217+
# -> (n_hex_to_subdiv, n_hex_each_hex, n_points_per_hex, 3)
218+
total_offsets = cell_offsets + point_offsets
219+
220+
# lets reshape the center points to broadcast
221+
center_points = (
222+
center_points.reshape((n_hex_to_subdiv, 1, 1, 3))
223+
.repeat(n_hex_each_hex, axis=1)
224+
.repeat(n_points_per_hex, axis=2)
225+
)
226+
227+
# Directly compute new point coordinates and reshape
228+
new_pts_coords = (center_points + total_offsets).reshape((n_new_pts, 3))
229+
# Compute new hex cell indices
230+
new_hex_cells = jnp.linspace(0, n_new_pts - 1, n_new_pts, dtype=jnp.int32).reshape(
231+
(n_new_cells, n_points_per_hex)
232+
)
233233

234234
def reindex_and_mask(
235235
coords: jnp.ndarray, cells: jnp.ndarray, keep_mask: jnp.ndarray
@@ -249,9 +249,6 @@ def reindex_and_mask(
249249

250250
return coords, cells
251251

252-
# new_pts_coords, new_hex_cells = reindex_and_mask(
253-
# new_pts_coords, new_hex_cells, mask.repeat(8)
254-
# )
255252
old_pts_coords, old_hex_cells = reindex_and_mask(
256253
pts_coords, hex_cells, jnp.logical_not(mask)
257254
)
@@ -297,21 +294,22 @@ def recursive_subdivide_hex_mesh(
297294
Returns:
298295
Subdivided points and hex cells.
299296
"""
300-
# lets build the kd-tree for fast nearest neighbor search
301297
xs = jnp.linspace(-Lx / 2, Lx / 2, sizing_field.shape[0])
302298
ys = jnp.linspace(-Ly / 2, Ly / 2, sizing_field.shape[1])
303299
zs = jnp.linspace(-Lz / 2, Lz / 2, sizing_field.shape[2])
304300

305301
interpolator = RegularGridInterpolator(
306-
(xs, ys, zs), sizing_field, method="linear", bounds_error=False, fill_value=-1
302+
(xs, ys, zs), sizing_field, method="nearest", bounds_error=False, fill_value=-1
307303
)
308304

309305
for i in range(levels):
310306
voxel_sizes = jnp.abs(pts_coords[hex_cells[:, 6]] - pts_coords[hex_cells[:, 0]])
311307

312-
voxel_center_points = jnp.mean(pts_coords[hex_cells], axis=1)
313-
sizing_values = interpolator(voxel_center_points)
314-
subdivision_mask = jnp.max(voxel_sizes, axis=-1) > sizing_values
308+
# voxel_center_points = jnp.mean(pts_coords[hex_cells], axis=1)
309+
sizing_values_pts = interpolator(pts_coords)
310+
voxel_sizing_min = jnp.min(sizing_values_pts[hex_cells], axis=1)
311+
312+
subdivision_mask = jnp.max(voxel_sizes, axis=-1) > voxel_sizing_min
315313

316314
if not jnp.any(subdivision_mask):
317315
print(f"No more subdivisions needed at level {i}.")
@@ -421,6 +419,7 @@ def apply(inputs: InputSchema) -> OutputSchema:
421419
)
422420

423421
cell_centers = jnp.mean(pts[cells], axis=1)
422+
424423
cell_values = interpolator(cell_centers)
425424

426425
cell_values_padded = jnp.zeros((inputs.max_cells,), dtype=cell_values.dtype)

examples/ansys/optim_bars.ipynb

Lines changed: 282 additions & 4907 deletions
Large diffs are not rendered by default.

examples/ansys/optim_grid.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
},
6868
{
6969
"cell_type": "code",
70-
"execution_count": 51,
70+
"execution_count": null,
7171
"id": "64ebfb56",
7272
"metadata": {},
7373
"outputs": [],

examples/ansys/rho_optim_x.gif

-1.25 MB
Loading

examples/ansys/sdf_fd_tess/tesseract_api.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,6 @@ def get_geometry(
108108
109109
The parameters are expected to be of shape (n_chains, n_edges_per_chain + 1, 3),
110110
"""
111-
print(
112-
{
113-
"differentiable_parameters": differentiable_parameters,
114-
"non_differentiable_parameters": non_differentiable_parameters,
115-
"static_parameters": static_parameters,
116-
"string_parameters": string_parameters,
117-
}
118-
)
119111
mesh = target.apply(
120112
{
121113
"differentiable_parameters": differentiable_parameters,

0 commit comments

Comments
 (0)