@@ -141,95 +141,95 @@ def vectorized_subdivide_hex_mesh(
141141 |/____ x
142142
143143 """
144- n_hex_subd = mask .sum ()
145- n_hex_each = (split_x + 1 ) * (split_y + 1 ) * (split_z + 1 )
146- n_new_pts = (
147- 8 * n_hex_each
148- ) * n_hex_subd # 8 corners per new hex, 8 new hexes per old hex
149- n_new_cells = n_hex_each * n_hex_subd
144+ # compute sizes
145+ n_hex_to_subdiv = mask .sum ()
146+ n_hex_each_hex = (split_x + 1 ) * (split_y + 1 ) * (split_z + 1 )
147+ n_points_per_hex = 8
148+ # 8 corners per new hex, 8 new hexes per old hex
149+ n_new_pts = n_points_per_hex * n_hex_each_hex * n_hex_to_subdiv
150+ n_new_cells = n_hex_each_hex * n_hex_to_subdiv
150151
151152 new_pts_coords = jnp .zeros ((n_new_pts , 3 ), dtype = pts_coords .dtype )
152153 new_hex_cells = jnp .zeros ((n_new_cells , 8 ), dtype = hex_cells .dtype )
153154
154- voxel_sizes = jnp .abs (
155- pts_coords [hex_cells [mask , 6 ]] - pts_coords [hex_cells [mask , 0 ]]
156- )
157-
158- center_points = jnp .mean (pts_coords [hex_cells [mask ]], axis = 1 ) # (n_hex, 3)
155+ # get sizes of hexes to subdivide
156+ hex_sizes = jnp .abs (pts_coords [hex_cells [mask , 6 ]] - pts_coords [hex_cells [mask , 0 ]])
157+ # Ceneter points of shape (n_hex_to_subdiv, 3)
158+ center_points = jnp .mean (pts_coords [hex_cells [mask ]], axis = 1 )
159159
160- cell_offsets = jnp .zeros ((1 , n_hex_each , 3 ), dtype = jnp .float32 )
160+ # Build cell offset tensor
161+ # that is the offset of a hex center to each of the new hex centers
162+ cell_offsets = jnp .zeros ((1 , n_hex_each_hex , 3 ), dtype = jnp .float32 )
161163 index = 0
162164 for ix in range (split_x + 1 ):
163165 for iy in range (split_y + 1 ):
164166 for iz in range (split_z + 1 ):
165167 cell_offsets = cell_offsets .at [0 , index ].set (
166168 jnp .array (
167169 [
168- (ix * 0.5 - 0.5 ) if split_x else 0.0 ,
169- (iy * 0.5 - 0.5 ) if split_y else 0.0 ,
170- (iz * 0.5 - 0.5 ) if split_z else 0.0 ,
170+ (0.25 - ix * 0.5 ) if split_x else 0.0 ,
171+ (0.25 - iy * 0.5 ) if split_y else 0.0 ,
172+ (0.25 - iz * 0.5 ) if split_z else 0.0 ,
171173 ]
172174 ).T
173175 )
174176 index += 1
175177
176- cell_offsets = cell_offsets .repeat (
177- voxel_sizes .shape [0 ], axis = 0
178- ) * voxel_sizes .reshape ((n_hex_subd , 1 , 3 )).repeat (n_hex_each , axis = 1 )
179-
180- offsets = jnp .array (
178+ # We now repeat the cell offsets and scale them by the corresponding hex sizes
179+ # Hence we have a cell_offset tensor of shape (n_hex_to_subdiv, n_hex_each_hex, 3)
180+ cell_offsets = cell_offsets .repeat (n_hex_to_subdiv , axis = 0 ) * hex_sizes .reshape (
181+ (n_hex_to_subdiv , 1 , 3 )
182+ ).repeat (n_hex_each_hex , axis = 1 )
183+
184+ # Build point offset tensor
185+ # that is the offset of a hex center to each of the new hex points
186+ offset_x = 0.25 if split_x else 0.5
187+ offset_y = 0.25 if split_y else 0.5
188+ offset_z = 0.25 if split_z else 0.5
189+ point_offsets = jnp .array (
181190 [
182- [- 0.25 , - 0.25 , - 0.25 ],
183- [0.25 , - 0.25 , - 0.25 ],
184- [0.25 , 0.25 , - 0.25 ],
185- [- 0.25 , 0.25 , - 0.25 ],
186- [- 0.25 , - 0.25 , 0.25 ],
187- [0.25 , - 0.25 , 0.25 ],
188- [0.25 , 0.25 , 0.25 ],
189- [- 0.25 , 0.25 , 0.25 ],
191+ [- offset_x , - offset_y , - offset_z ],
192+ [offset_x , - offset_y , - offset_z ],
193+ [offset_x , offset_y , - offset_z ],
194+ [- offset_x , offset_y , - offset_z ],
195+ [- offset_x , - offset_y , offset_z ],
196+ [offset_x , - offset_y , offset_z ],
197+ [offset_x , offset_y , offset_z ],
198+ [- offset_x , offset_y , offset_z ],
190199 ]
191- ).reshape ((1 , 8 , 3 )).repeat (voxel_sizes .shape [0 ], axis = 0 ) * voxel_sizes .reshape (
192- (n_hex_subd , 1 , 3 )
193- ).repeat (8 , axis = 1 )
194-
195- for cell in range (n_hex_each ):
196- center = center_points + cell_offsets [:, cell ]
197-
198- for corner in range (8 ):
199- new_pts_coords = new_pts_coords .at [
200- jnp .arange (n_hex_subd ) * 8 * n_hex_each + cell * n_hex_each + corner
201- ].set (center + offsets [:, corner ])
202-
203- new_hex_cells = new_hex_cells .at [
204- jnp .arange (n_hex_subd ) * n_hex_each + cell , corner
205- ].set (jnp .arange (n_hex_subd ) * 8 * n_hex_each + cell * n_hex_each + corner )
206-
207- # offsets = jnp.array(
208- # [
209- # [-0.25, -0.25, -0.25],
210- # [0.25, -0.25, -0.25],
211- # [0.25, 0.25, -0.25],
212- # [-0.25, 0.25, -0.25],
213- # [-0.25, -0.25, 0.25],
214- # [0.25, -0.25, 0.25],
215- # [0.25, 0.25, 0.25],
216- # [-0.25, 0.25, 0.25],
217- # ]
218- # ).reshape((1, 8, 3)).repeat(voxel_sizes.shape[0], axis=0) * voxel_sizes.reshape(
219- # (n_hex_new, 1, 3)
220- # ).repeat(8, axis=1)
221-
222- # for cell in range(8):
223- # center = center_points + offsets[:, cell]
224-
225- # for corner in range(8):
226- # new_pts_coords = new_pts_coords.at[
227- # jnp.arange(n_hex_new) * 64 + cell * 8 + corner
228- # ].set(center + offsets[:, corner])
229-
230- # new_hex_cells = new_hex_cells.at[
231- # jnp.arange(n_hex_new) * 8 + cell, corner
232- # ].set(jnp.arange(n_hex_new) * 64 + cell * 8 + corner)
200+ )
201+
202+ # Repeat the point offsets and scale them by the corresponding hex sizes
203+ # -> point_offset tensor of shape (n_hex_to_subdiv, n_points_per_hex, 3)
204+ point_offsets = point_offsets .reshape ((1 , n_points_per_hex , 3 )).repeat (
205+ hex_sizes .shape [0 ], axis = 0
206+ ) * hex_sizes .reshape ((n_hex_to_subdiv , 1 , 3 )).repeat (n_points_per_hex , axis = 1 )
207+
208+ # Repeat the two offsets at an additional axis to get all combinations
209+ cell_offsets = cell_offsets .reshape ((n_hex_to_subdiv , n_hex_each_hex , 1 , 3 )).repeat (
210+ n_points_per_hex , axis = 2
211+ )
212+ point_offsets = point_offsets .reshape (
213+ (n_hex_to_subdiv , 1 , n_points_per_hex , 3 )
214+ ).repeat (n_hex_each_hex , axis = 1 )
215+
216+ # Compute total offset relative to old hex center
217+ # -> (n_hex_to_subdiv, n_hex_each_hex, n_points_per_hex, 3)
218+ total_offsets = cell_offsets + point_offsets
219+
220+ # lets reshape the center points to broadcast
221+ center_points = (
222+ center_points .reshape ((n_hex_to_subdiv , 1 , 1 , 3 ))
223+ .repeat (n_hex_each_hex , axis = 1 )
224+ .repeat (n_points_per_hex , axis = 2 )
225+ )
226+
227+ # Directly compute new point coordinates and reshape
228+ new_pts_coords = (center_points + total_offsets ).reshape ((n_new_pts , 3 ))
229+ # Compute new hex cell indices
230+ new_hex_cells = jnp .linspace (0 , n_new_pts - 1 , n_new_pts , dtype = jnp .int32 ).reshape (
231+ (n_new_cells , n_points_per_hex )
232+ )
233233
234234 def reindex_and_mask (
235235 coords : jnp .ndarray , cells : jnp .ndarray , keep_mask : jnp .ndarray
@@ -249,9 +249,6 @@ def reindex_and_mask(
249249
250250 return coords , cells
251251
252- # new_pts_coords, new_hex_cells = reindex_and_mask(
253- # new_pts_coords, new_hex_cells, mask.repeat(8)
254- # )
255252 old_pts_coords , old_hex_cells = reindex_and_mask (
256253 pts_coords , hex_cells , jnp .logical_not (mask )
257254 )
@@ -297,21 +294,22 @@ def recursive_subdivide_hex_mesh(
297294 Returns:
298295 Subdivided points and hex cells.
299296 """
300- # lets build the kd-tree for fast nearest neighbor search
301297 xs = jnp .linspace (- Lx / 2 , Lx / 2 , sizing_field .shape [0 ])
302298 ys = jnp .linspace (- Ly / 2 , Ly / 2 , sizing_field .shape [1 ])
303299 zs = jnp .linspace (- Lz / 2 , Lz / 2 , sizing_field .shape [2 ])
304300
305301 interpolator = RegularGridInterpolator (
306- (xs , ys , zs ), sizing_field , method = "linear " , bounds_error = False , fill_value = - 1
302+ (xs , ys , zs ), sizing_field , method = "nearest " , bounds_error = False , fill_value = - 1
307303 )
308304
309305 for i in range (levels ):
310306 voxel_sizes = jnp .abs (pts_coords [hex_cells [:, 6 ]] - pts_coords [hex_cells [:, 0 ]])
311307
312- voxel_center_points = jnp .mean (pts_coords [hex_cells ], axis = 1 )
313- sizing_values = interpolator (voxel_center_points )
314- subdivision_mask = jnp .max (voxel_sizes , axis = - 1 ) > sizing_values
308+ # voxel_center_points = jnp.mean(pts_coords[hex_cells], axis=1)
309+ sizing_values_pts = interpolator (pts_coords )
310+ voxel_sizing_min = jnp .min (sizing_values_pts [hex_cells ], axis = 1 )
311+
312+ subdivision_mask = jnp .max (voxel_sizes , axis = - 1 ) > voxel_sizing_min
315313
316314 if not jnp .any (subdivision_mask ):
317315 print (f"No more subdivisions needed at level { i } ." )
@@ -421,6 +419,7 @@ def apply(inputs: InputSchema) -> OutputSchema:
421419 )
422420
423421 cell_centers = jnp .mean (pts [cells ], axis = 1 )
422+
424423 cell_values = interpolator (cell_centers )
425424
426425 cell_values_padded = jnp .zeros ((inputs .max_cells ,), dtype = cell_values .dtype )
0 commit comments