Skip to content

Commit 8ce9be4

Browse files
committed
working beam optim
1 parent e01d226 commit 8ce9be4

File tree

9 files changed

+2860
-3794
lines changed

9 files changed

+2860
-3794
lines changed

examples/ansys/bars_mesh.vtk

158 KB
Binary file not shown.

examples/ansys/gf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
images = []
44

5-
for i in range(3):
5+
for i in range(10):
66
filename = f"tmp_img/mesh_optim_{i:03d}.png"
77
images.append(imageio.imread(filename))
88
print(f"Added {filename} to gif.")

examples/ansys/mesh_optim.gif

-78 KB
Loading

examples/ansys/meshing_tess/tesseract_api.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -82,36 +82,39 @@ class OutputSchema(BaseModel):
8282
#
8383
# Helper functions
8484
#
85-
def create_single_hex(
86-
Lx: float,
87-
Ly: float,
88-
Lz: float,
85+
86+
87+
def hex_grid(
88+
Lx: float, Ly: float, Lz: float, Nx: int, Ny: int, Nz: int
8989
) -> tuple[jnp.ndarray, jnp.ndarray]:
90-
"""Create a single HEX8 mesh of a cuboid domain."""
91-
# Define the 8 corner points of the hexahedron
92-
points = jnp.array(
93-
[
94-
[-Lx / 2, -Ly / 2, -Lz / 2], # Point 0
95-
[Lx / 2, -Ly / 2, -Lz / 2], # Point 1
96-
[Lx / 2, Ly / 2, -Lz / 2], # Point 2
97-
[-Lx / 2, Ly / 2, -Lz / 2], # Point 3
98-
[-Lx / 2, -Ly / 2, Lz / 2], # Point 4
99-
[Lx / 2, -Ly / 2, Lz / 2], # Point 5
100-
[Lx / 2, Ly / 2, Lz / 2], # Point 6
101-
[-Lx / 2, Ly / 2, Lz / 2], # Point 7
102-
],
103-
dtype=jnp.float32,
104-
)
90+
"""Creates a hex mesh with Nx * Ny * Nz points.
10591
106-
# Define the hexahedron cell using the point indices
107-
hex_cells = jnp.array(
108-
[
109-
[0, 1, 2, 3, 4, 5, 6, 7] # Single HEX8 element
110-
],
111-
dtype=jnp.int32,
112-
)
92+
This is (Nx-1) * (Ny-1) * (Nz-1) cells
93+
"""
94+
xs = jnp.linspace(-Lx / 2, Lx / 2, Nx)
95+
ys = jnp.linspace(-Ly / 2, Ly / 2, Ny)
96+
zs = jnp.linspace(-Lz / 2, Lz / 2, Nz)
11397

114-
return points, hex_cells
98+
xs, ys, zs = jnp.meshgrid(xs, ys, zs, indexing="ij")
99+
100+
pts = jnp.stack((xs, ys, zs), -1)
101+
102+
points_inds = jnp.arange(Nx * Ny * Nz)
103+
points_inds_xyz = points_inds.reshape(Nx, Ny, Nz)
104+
inds1 = points_inds_xyz[:-1, :-1, :-1]
105+
inds2 = points_inds_xyz[1:, :-1, :-1]
106+
inds3 = points_inds_xyz[1:, 1:, :-1]
107+
inds4 = points_inds_xyz[:-1, 1:, :-1]
108+
inds5 = points_inds_xyz[:-1, :-1, 1:]
109+
inds6 = points_inds_xyz[1:, :-1, 1:]
110+
inds7 = points_inds_xyz[1:, 1:, 1:]
111+
inds8 = points_inds_xyz[:-1, 1:, 1:]
112+
113+
cells = jnp.stack(
114+
(inds1, inds2, inds3, inds4, inds5, inds6, inds7, inds8), axis=-1
115+
).reshape(-1, 8)
116+
117+
return pts.reshape(-1, 3), cells
115118

116119

117120
def vectorized_subdivide_hex_mesh(
@@ -169,9 +172,9 @@ def vectorized_subdivide_hex_mesh(
169172
cell_offsets = cell_offsets.at[0, index].set(
170173
jnp.array(
171174
[
172-
(0.25 - ix * 0.5) if split_x else 0.0,
173-
(0.25 - iy * 0.5) if split_y else 0.0,
174-
(0.25 - iz * 0.5) if split_z else 0.0,
175+
(-0.25 + ix * 0.5) if split_x else 0.0,
176+
(-0.25 + iy * 0.5) if split_y else 0.0,
177+
(-0.25 + iz * 0.5) if split_z else 0.0,
175178
]
176179
).T
177180
)
@@ -204,7 +207,7 @@ def vectorized_subdivide_hex_mesh(
204207
# Repeat the point offsets and scale them by the corresponding hex sizes
205208
# -> point_offset tensor of shape (n_hex_to_subdiv, n_points_per_hex, 3)
206209
point_offsets = point_offsets.reshape((1, n_points_per_hex, 3)).repeat(
207-
hex_sizes.shape[0], axis=0
210+
n_hex_to_subdiv, axis=0
208211
) * hex_sizes.reshape((n_hex_to_subdiv, 1, 3)).repeat(n_points_per_hex, axis=1)
209212

210213
# Repeat the two offsets at an additional axis to get all combinations
@@ -229,7 +232,7 @@ def vectorized_subdivide_hex_mesh(
229232
# Directly compute new point coordinates and reshape
230233
new_pts_coords = (center_points + total_offsets).reshape((n_new_pts, 3))
231234
# Compute new hex cell indices
232-
new_hex_cells = jnp.linspace(0, n_new_pts - 1, n_new_pts, dtype=jnp.int32).reshape(
235+
new_hex_cells = jnp.arange(n_new_pts, dtype=jnp.int32).reshape(
233236
(n_new_cells, n_points_per_hex)
234237
)
235238

@@ -267,10 +270,17 @@ def remove_duplicate_points(
267270
pts_coords: jnp.ndarray, hex_cells: jnp.ndarray
268271
) -> tuple[jnp.ndarray, jnp.ndarray]:
269272
"""Remove duplicate points from the mesh and update hex cell indices."""
270-
unique_pts, inverse_indices = jnp.unique(pts_coords, axis=0, return_inverse=True)
271-
updated_hex_cells = inverse_indices[hex_cells]
273+
# TODO: remove rounding after removing duplicate points
274+
pts_coords = jnp.round(pts_coords, decimals=5)
275+
_, indices, inverse_indices = jnp.unique(
276+
pts_coords, axis=0, return_index=True, return_inverse=True
277+
)
272278

273-
return unique_pts, updated_hex_cells
279+
pts_coords = pts_coords[indices]
280+
281+
hex_cells = inverse_indices[hex_cells]
282+
283+
return pts_coords, hex_cells
274284

275285

276286
def recursive_subdivide_hex_mesh(
@@ -368,7 +378,14 @@ def generate_mesh(
368378
points: (n_points, 3) array of vertex positions.
369379
hex_cells: (n_hex, 8) array of hexahedron cell indices.
370380
"""
371-
initial_pts, initial_hex_cells = create_single_hex(Lx, Ly, Lz)
381+
# get largest cell size
382+
max_size = jnp.max(sizing_field)
383+
384+
Nx = int(Lx / max_size)
385+
Ny = int(Ly / max_size)
386+
Nz = int(Lz / max_size)
387+
388+
initial_pts, initial_hex_cells = hex_grid(Lx, Ly, Lz, Nx, Ny, Nz)
372389

373390
pts, cells = recursive_subdivide_hex_mesh(
374391
initial_hex_cells,

examples/ansys/optim_bars.ipynb

Lines changed: 2806 additions & 3757 deletions
Large diffs are not rendered by default.

examples/ansys/rho_optim_sum_2.gif

383 KB
Loading
-400 KB
Binary file not shown.
-400 KB
Binary file not shown.
-400 KB
Binary file not shown.

0 commit comments

Comments
 (0)