@@ -82,36 +82,39 @@ class OutputSchema(BaseModel):
8282#
8383# Helper functions
8484#
85- def create_single_hex (
86- Lx : float ,
87- Ly : float ,
88- Lz : float ,
85+
86+
87+ def hex_grid (
88+ Lx : float , Ly : float , Lz : float , Nx : int , Ny : int , Nz : int
8989) -> tuple [jnp .ndarray , jnp .ndarray ]:
90- """Create a single HEX8 mesh of a cuboid domain."""
91- # Define the 8 corner points of the hexahedron
92- points = jnp .array (
93- [
94- [- Lx / 2 , - Ly / 2 , - Lz / 2 ], # Point 0
95- [Lx / 2 , - Ly / 2 , - Lz / 2 ], # Point 1
96- [Lx / 2 , Ly / 2 , - Lz / 2 ], # Point 2
97- [- Lx / 2 , Ly / 2 , - Lz / 2 ], # Point 3
98- [- Lx / 2 , - Ly / 2 , Lz / 2 ], # Point 4
99- [Lx / 2 , - Ly / 2 , Lz / 2 ], # Point 5
100- [Lx / 2 , Ly / 2 , Lz / 2 ], # Point 6
101- [- Lx / 2 , Ly / 2 , Lz / 2 ], # Point 7
102- ],
103- dtype = jnp .float32 ,
104- )
90+ """Creates a hex mesh with Nx * Ny * Nz points.
10591
106- # Define the hexahedron cell using the point indices
107- hex_cells = jnp .array (
108- [
109- [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] # Single HEX8 element
110- ],
111- dtype = jnp .int32 ,
112- )
92+ This is (Nx-1) * (Ny-1) * (Nz-1) cells
93+ """
94+ xs = jnp .linspace (- Lx / 2 , Lx / 2 , Nx )
95+ ys = jnp .linspace (- Ly / 2 , Ly / 2 , Ny )
96+ zs = jnp .linspace (- Lz / 2 , Lz / 2 , Nz )
11397
114- return points , hex_cells
98+ xs , ys , zs = jnp .meshgrid (xs , ys , zs , indexing = "ij" )
99+
100+ pts = jnp .stack ((xs , ys , zs ), - 1 )
101+
102+ points_inds = jnp .arange (Nx * Ny * Nz )
103+ points_inds_xyz = points_inds .reshape (Nx , Ny , Nz )
104+ inds1 = points_inds_xyz [:- 1 , :- 1 , :- 1 ]
105+ inds2 = points_inds_xyz [1 :, :- 1 , :- 1 ]
106+ inds3 = points_inds_xyz [1 :, 1 :, :- 1 ]
107+ inds4 = points_inds_xyz [:- 1 , 1 :, :- 1 ]
108+ inds5 = points_inds_xyz [:- 1 , :- 1 , 1 :]
109+ inds6 = points_inds_xyz [1 :, :- 1 , 1 :]
110+ inds7 = points_inds_xyz [1 :, 1 :, 1 :]
111+ inds8 = points_inds_xyz [:- 1 , 1 :, 1 :]
112+
113+ cells = jnp .stack (
114+ (inds1 , inds2 , inds3 , inds4 , inds5 , inds6 , inds7 , inds8 ), axis = - 1
115+ ).reshape (- 1 , 8 )
116+
117+ return pts .reshape (- 1 , 3 ), cells
115118
116119
117120def vectorized_subdivide_hex_mesh (
@@ -169,9 +172,9 @@ def vectorized_subdivide_hex_mesh(
169172 cell_offsets = cell_offsets .at [0 , index ].set (
170173 jnp .array (
171174 [
172- (0.25 - ix * 0.5 ) if split_x else 0.0 ,
173- (0.25 - iy * 0.5 ) if split_y else 0.0 ,
174- (0.25 - iz * 0.5 ) if split_z else 0.0 ,
175+ (- 0.25 + ix * 0.5 ) if split_x else 0.0 ,
176+ (- 0.25 + iy * 0.5 ) if split_y else 0.0 ,
177+ (- 0.25 + iz * 0.5 ) if split_z else 0.0 ,
175178 ]
176179 ).T
177180 )
@@ -204,7 +207,7 @@ def vectorized_subdivide_hex_mesh(
204207 # Repeat the point offsets and scale them by the corresponding hex sizes
205208 # -> point_offset tensor of shape (n_hex_to_subdiv, n_points_per_hex, 3)
206209 point_offsets = point_offsets .reshape ((1 , n_points_per_hex , 3 )).repeat (
207- hex_sizes . shape [ 0 ] , axis = 0
210+ n_hex_to_subdiv , axis = 0
208211 ) * hex_sizes .reshape ((n_hex_to_subdiv , 1 , 3 )).repeat (n_points_per_hex , axis = 1 )
209212
210213 # Repeat the two offsets at an additional axis to get all combinations
@@ -229,7 +232,7 @@ def vectorized_subdivide_hex_mesh(
229232 # Directly compute new point coordinates and reshape
230233 new_pts_coords = (center_points + total_offsets ).reshape ((n_new_pts , 3 ))
231234 # Compute new hex cell indices
232- new_hex_cells = jnp .linspace ( 0 , n_new_pts - 1 , n_new_pts , dtype = jnp .int32 ).reshape (
235+ new_hex_cells = jnp .arange ( n_new_pts , dtype = jnp .int32 ).reshape (
233236 (n_new_cells , n_points_per_hex )
234237 )
235238
@@ -267,10 +270,17 @@ def remove_duplicate_points(
267270 pts_coords : jnp .ndarray , hex_cells : jnp .ndarray
268271) -> tuple [jnp .ndarray , jnp .ndarray ]:
269272 """Remove duplicate points from the mesh and update hex cell indices."""
270- unique_pts , inverse_indices = jnp .unique (pts_coords , axis = 0 , return_inverse = True )
271- updated_hex_cells = inverse_indices [hex_cells ]
273+ # TODO: remove rounding after removing duplicate points
274+ pts_coords = jnp .round (pts_coords , decimals = 5 )
275+ _ , indices , inverse_indices = jnp .unique (
276+ pts_coords , axis = 0 , return_index = True , return_inverse = True
277+ )
272278
273- return unique_pts , updated_hex_cells
279+ pts_coords = pts_coords [indices ]
280+
281+ hex_cells = inverse_indices [hex_cells ]
282+
283+ return pts_coords , hex_cells
274284
275285
276286def recursive_subdivide_hex_mesh (
@@ -368,7 +378,14 @@ def generate_mesh(
368378 points: (n_points, 3) array of vertex positions.
369379 hex_cells: (n_hex, 8) array of hexahedron cell indices.
370380 """
371- initial_pts , initial_hex_cells = create_single_hex (Lx , Ly , Lz )
381+ # get largest cell size
382+ max_size = jnp .max (sizing_field )
383+
384+ Nx = int (Lx / max_size )
385+ Ny = int (Ly / max_size )
386+ Nz = int (Lz / max_size )
387+
388+ initial_pts , initial_hex_cells = hex_grid (Lx , Ly , Lz , Nx , Ny , Nz )
372389
373390 pts , cells = recursive_subdivide_hex_mesh (
374391 initial_hex_cells ,
0 commit comments