@@ -131,6 +131,7 @@ def vectorized_subdivide_hex_mesh(
131131 This method introduces duplicates of points that should later be merged.
132132
133133 Hexahedron is constructed as follows:
134+
134135 3 -------- 2
135136 /| /|
136137 7 -------- 6 |
@@ -140,23 +141,26 @@ def vectorized_subdivide_hex_mesh(
140141 4 -------- 5
141142
142143 Axis orientation:
143- y
144- |
145- |____ x
146- /
147- /
148- z
144+
145+ y
146+ |
147+ |____ x
148+ /
149+ /
150+ z
149151
150152 """
151- n_hex = hex_cells . shape [ 0 ]
152- n_new_pts = (8 * 8 ) * n_hex # 8 corners per new hex, 8 new hexes per old hex
153+ n_hex_new = mask . sum ()
154+ n_new_pts = (8 * 8 ) * n_hex_new # 8 corners per new hex, 8 new hexes per old hex
153155
154156 new_pts_coords = jnp .zeros ((n_new_pts , 3 ), dtype = pts_coords .dtype )
155- new_hex_cells = jnp .zeros ((n_hex * 8 , 8 ), dtype = hex_cells .dtype )
157+ new_hex_cells = jnp .zeros ((n_hex_new * 8 , 8 ), dtype = hex_cells .dtype )
156158
157- voxel_sizes = jnp .abs (pts_coords [hex_cells [:, 6 ]] - pts_coords [hex_cells [:, 0 ]])
159+ voxel_sizes = jnp .abs (
160+ pts_coords [hex_cells [mask , 6 ]] - pts_coords [hex_cells [mask , 0 ]]
161+ )
158162
159- center_points = jnp .mean (pts_coords [hex_cells ], axis = 1 ) # (n_hex, 3)
163+ center_points = jnp .mean (pts_coords [hex_cells [ mask ] ], axis = 1 ) # (n_hex, 3)
160164 offsets = jnp .array (
161165 [
162166 [- 0.25 , - 0.25 , - 0.25 ],
@@ -169,20 +173,20 @@ def vectorized_subdivide_hex_mesh(
169173 [- 0.25 , 0.25 , 0.25 ],
170174 ]
171175 ).reshape ((1 , 8 , 3 )).repeat (voxel_sizes .shape [0 ], axis = 0 ) * voxel_sizes .reshape (
172- (n_hex , 1 , 3 )
176+ (n_hex_new , 1 , 3 )
173177 ).repeat (8 , axis = 1 )
174178
175179 for cell in range (8 ):
176180 center = center_points + offsets [:, cell ]
177181
178182 for corner in range (8 ):
179183 new_pts_coords = new_pts_coords .at [
180- jnp .arange (n_hex ) * 64 + cell * 8 + corner
184+ jnp .arange (n_hex_new ) * 64 + cell * 8 + corner
181185 ].set (center - offsets [:, corner ])
182186
183- new_hex_cells = new_hex_cells .at [jnp . arange ( n_hex ) * 8 + cell , corner ]. set (
184- jnp .arange (n_hex ) * 64 + cell * 8 + corner
185- )
187+ new_hex_cells = new_hex_cells .at [
188+ jnp .arange (n_hex_new ) * 8 + cell , corner
189+ ]. set ( jnp . arange ( n_hex_new ) * 64 + cell * 8 + corner )
186190
187191 def reindex_and_mask (
188192 coords : jnp .ndarray , cells : jnp .ndarray , keep_mask : jnp .ndarray
@@ -202,9 +206,9 @@ def reindex_and_mask(
202206
203207 return coords , cells
204208
205- new_pts_coords , new_hex_cells = reindex_and_mask (
206- new_pts_coords , new_hex_cells , mask .repeat (8 )
207- )
209+ # new_pts_coords, new_hex_cells = reindex_and_mask(
210+ # new_pts_coords, new_hex_cells, mask.repeat(8)
211+ # )
208212 old_pts_coords , old_hex_cells = reindex_and_mask (
209213 pts_coords , hex_cells , jnp .logical_not (mask )
210214 )
@@ -279,7 +283,6 @@ def recursive_subdivide_hex_mesh(
279283
280284 return pts_coords , hex_cells
281285
282- mesh = None # cache for the last generated mesh
283286
284287# @lru_cache(maxsize=1)
285288def generate_mesh (
@@ -314,8 +317,6 @@ def generate_mesh(
314317 Lz = Lz ,
315318 )
316319
317- mesh = (pts , cells ) # cache the generated mesh
318-
319320 return pts , cells
320321
321322
@@ -397,17 +398,14 @@ def vector_jacobian_product(
397398 assert vjp_inputs == {"field_values" }
398399 assert vjp_outputs == {"mesh_cell_values" }
399400
400- if mesh is None :
401- pts , cells = generate_mesh (
402- Lx = inputs .Lx ,
403- Ly = inputs .Ly ,
404- Lz = inputs .Lz ,
405- sizing_field = inputs .sizing_field ,
406- max_levels = inputs .max_subdivision_levels ,
407- )
408- else :
409- print ("Using cached mesh for VJP computation." )
410- pts , cells = mesh
401+ pts , cells = generate_mesh (
402+ Lx = inputs .Lx ,
403+ Ly = inputs .Ly ,
404+ Lz = inputs .Lz ,
405+ sizing_field = inputs .sizing_field ,
406+ max_levels = inputs .max_subdivision_levels ,
407+ )
408+
411409 cell_centers = jnp .mean (pts [cells ], axis = 1 )
412410
413411 xs = jnp .linspace (- inputs .Lx / 2 , inputs .Lx / 2 , inputs .field_values .shape [0 ])
0 commit comments