44
55import jax
66import jax .numpy as jnp
7- import numpy as np
87import meshio
9- from jax_fem .generate_mesh import Mesh , box_mesh , get_meshio_cell_type
8+ from jax_fem .generate_mesh import Mesh
109
1110# Import JAX-FEM specific modules
1211from jax_fem .problem import Problem
@@ -64,31 +63,6 @@ class InputSchema(BaseModel):
6463 hex_mesh : HexMesh = Field (
6564 description = "Hexahedral mesh representation of the geometry" ,
6665 )
67- Lx : float = Field (
68- default = 60.0 , description = "Length of the simulation box in the x direction."
69- )
70- Ly : float = Field (
71- default = 30.0 ,
72- description = ("Length of the simulation box in the y direction." ),
73- )
74- Lz : float = Field (
75- default = 30.0 , description = "Length of the simulation box in the z direction."
76- )
77- Nx : int = Field (
78- default = 60 ,
79- description = ("Number of elements in the x direction." ),
80- )
81- Ny : int = Field (
82- default = 30 ,
83- description = ("Number of elements in the y direction." ),
84- )
85- Nz : int = Field (
86- default = 30 ,
87- description = ("Number of elements in the z direction." ),
88- )
89- use_regular_grid : bool = Field (
90- description = "Toggle to use a regular grid mesh instead of imported mesh" ,
91- )
9266
9367
9468class OutputSchema (BaseModel ):
@@ -131,7 +105,7 @@ def get_tensor_map(self) -> Callable:
131105 Callable that computes stress from strain gradient and density.
132106 """
133107
134- def stress (u_grad , theta ): # noqa: ANN001
108+ def stress (u_grad , theta ):
135109 Emax = 70.0e3
136110 Emin = 1e-3 * Emax
137111 nu = 0.3
@@ -166,12 +140,13 @@ def get_surface_maps(self) -> list[Callable]:
166140 Returns:
167141 List of van Neumann boundary condition value functions.
168142 """
169- # def surface_map(u, x):
170- # return jnp.array([0.0, 0.0, 100.0])
171143
172- # return [surface_map]
144+ def surface_map (u , x ):
145+ return jnp .array ([0.0 , 0.0 , 100.0 ])
173146
174- return self .van_neumann_value_fns
147+ return [surface_map ]
148+
149+ # return self.van_neumann_value_fns
175150
176151 def set_params (self , params : jnp .ndarray ) -> None :
177152 """Set density parameters for topology optimization.
@@ -205,20 +180,14 @@ def compute_compliance(self, sol: jnp.ndarray) -> jnp.ndarray:
205180 u_face = jnp .sum (u_face , axis = 2 )
206181 subset_quad_points = self .physical_surface_quad_points [0 ]
207182 neumann_fn = self .get_surface_maps ()[0 ]
208- traction = jax .vmap (jax .vmap (neumann_fn ))(u_face , subset_quad_points )
183+ traction = - jax .vmap (jax .vmap (neumann_fn ))(u_face , subset_quad_points )
209184 val = jnp .sum (traction * u_face * nanson_scale [:, :, None ])
210185 return val
211186
212187
213188# Memoize the setup function to avoid expensive recomputation
214189# @lru_cache(maxsize=1)
215190def setup (
216- Nx : int = 60 ,
217- Ny : int = 30 ,
218- Nz : int = 30 ,
219- Lx : float = 60.0 ,
220- Ly : float = 30.0 ,
221- Lz : float = 30.0 ,
222191 pts : jnp .ndarray = None ,
223192 cells : jnp .ndarray = None ,
224193 dirichlet_mask : jnp .ndarray = None ,
@@ -229,12 +198,6 @@ def setup(
229198 """Setup the elasticity problem and its differentiable solver.
230199
231200 Args:
232- Nx: Number of elements in x direction for regular grid.
233- Ny: Number of elements in y direction for regular grid.
234- Nz: Number of elements in z direction for regular grid.
235- Lx: Length of the domain in x direction for regular grid.
236- Ly: Length of the domain in y direction for regular grid.
237- Lz: Length of the domain in z direction for regular grid.
238201 pts: Optional array of mesh vertex positions for custom mesh.
239202 cells: Optional array of hexahedral cell definitions for custom mesh.
240203 dirichlet_mask: Mask array for Dirichlet boundary conditions.
@@ -247,56 +210,11 @@ def setup(
247210 problem instance and fwd_pred is the differentiable forward solver.
248211 """
249212 ele_type = "HEX8"
250- if pts is None and cells is None :
251- cell_type = get_meshio_cell_type (ele_type )
252- meshio_mesh = box_mesh (
253- Nx = Nx , Ny = Ny , Nz = Nz , domain_x = Lx , domain_y = Ly , domain_z = Lz
254- )
255- mesh = Mesh (meshio_mesh .points , meshio_mesh .cells_dict [cell_type ])
256- else :
257- meshio_mesh = meshio .Mesh (points = pts , cells = {"hexahedron" : cells })
258- mesh = Mesh (pts , meshio_mesh .cells_dict ["hexahedron" ])
259-
260- def bc_factory (
261- masks : jnp .ndarray ,
262- values : jnp .ndarray ,
263- is_van_neumann : bool = False ,
264- ) -> tuple [list [Callable ], list [Callable ]]:
265- location_functions = []
266- value_functions = []
267- for i in range (values .shape [0 ]):
268-
269- def location_fn (point , index ): # noqa: ANN001
270- # return mask[index]
271- return (
272- jax .lax .dynamic_index_in_dim (masks , index , 0 , keepdims = False ) == i
273- )
274-
275- def value_fn (point ): # noqa: ANN001
276- return values [i ]
277-
278- def value_fn_vn (u , x ): # noqa: ANN001
279- return values [i ]
280-
281- location_functions .append (location_fn )
282- value_functions .append (value_fn_vn if is_van_neumann else value_fn )
283-
284- return location_functions , value_functions
285-
286- dirichlet_values = np .array (dirichlet_values )
287- van_neumann_values = np .array (van_neumann_values )
288-
289- dirichlet_location_fns , dirichlet_value_fns = bc_factory (
290- dirichlet_mask , dirichlet_values
291- )
292213
293- van_neumann_locations , van_neumann_value_fns = bc_factory (
294- van_neumann_mask , van_neumann_values , is_van_neumann = True
295- )
296-
297- dirichlet_bc_info = [dirichlet_location_fns * 3 , [0 , 1 , 2 ], dirichlet_value_fns * 3 ]
214+ meshio_mesh = meshio .Mesh (points = pts , cells = {"hexahedron" : cells })
215+ mesh = Mesh (pts , meshio_mesh .cells_dict ["hexahedron" ])
298216
299- location_fns = van_neumann_locations
217+ print ( f"pts min: { jnp . min ( pts , axis = 0 ) } , pts max: { jnp . max ( pts , axis = 0 ) } " )
300218
301219 # # Define boundary conditions and values.
302220 # def fixed_location(point):
@@ -343,6 +261,26 @@ def value_fn_vn(u, x): # noqa: ANN001
343261
344262 # location_fns = [load_location]
345263
264+ Lx = jnp .max (pts [:, 0 ]) - jnp .min (pts [:, 0 ])
265+ Ly = jnp .max (pts [:, 1 ]) - jnp .min (pts [:, 1 ])
266+ # Lz = jnp.max(pts[:, 2]) - jnp.min(pts[:, 2])
267+
268+ def fixed_location (point ):
269+ return jnp .isclose (point [0 ], 0.0 , atol = 1e-5 )
270+
271+ def load_location (point ):
272+ return jnp .logical_and (
273+ jnp .isclose (point [0 ], Lx , atol = 1e-5 ),
274+ jnp .isclose (point [1 ], 0.0 , atol = 0.1 * Ly + 1e-5 ),
275+ )
276+
277+ def dirichlet_val (point ):
278+ return 0.0
279+
280+ dirichlet_bc_info = [[fixed_location ] * 3 , [0 , 1 , 2 ], [dirichlet_val ] * 3 ]
281+
282+ location_fns = [load_location ]
283+
346284 # Define forward problem
347285 problem = Elasticity (
348286 mesh ,
@@ -351,8 +289,8 @@ def value_fn_vn(u, x): # noqa: ANN001
351289 ele_type = ele_type ,
352290 dirichlet_bc_info = dirichlet_bc_info ,
353291 location_fns = location_fns ,
354- additional_info = (van_neumann_value_fns ,),
355- # additional_info=([0.1],)
292+ # additional_info=(van_neumann_value_fns,),
293+ additional_info = ([0.1 ],),
356294 )
357295
358296 # Apply the automatic differentiation wrapper
@@ -374,53 +312,20 @@ def apply_fn(inputs: dict) -> dict:
374312 Returns:
375313 Dictionary containing the compliance of the structure.
376314 """
377-
378- if not inputs ["use_regular_grid" ]:
379- problem , fwd_pred = setup (
380- Nx = inputs ["Nx" ],
381- Ny = inputs ["Ny" ],
382- Nz = inputs ["Nz" ],
383- Lx = inputs ["Lx" ],
384- Ly = inputs ["Ly" ],
385- Lz = inputs ["Lz" ],
386- # pts=jax.lax.dynamic_slice_in_dim(
387- # inputs["hex_mesh"]["points"],
388- # 0,
389- # inputs["hex_mesh"]["n_points"],
390- # axis=0,
391- # ),
392- # cells=jax.lax.dynamic_slice_in_dim(
393- # inputs["hex_mesh"]["faces"],
394- # 0,
395- # inputs["hex_mesh"]["n_faces"],
396- # axis=0,
397- # ),
398- pts = inputs ["hex_mesh" ]["points" ][: inputs ["hex_mesh" ]["n_points" ]],
399- cells = inputs ["hex_mesh" ]["faces" ][: inputs ["hex_mesh" ]["n_faces" ]],
400- dirichlet_mask = inputs ["dirichlet_mask" ],
401- dirichlet_values = inputs ["dirichlet_values" ],
402- van_neumann_mask = inputs ["van_neumann_mask" ],
403- van_neumann_values = inputs ["van_neumann_values" ],
404- )
405- else :
406- problem , fwd_pred = setup (
407- Nx = inputs ["Nx" ],
408- Ny = inputs ["Ny" ],
409- Nz = inputs ["Nz" ],
410- Lx = inputs ["Lx" ],
411- Ly = inputs ["Ly" ],
412- Lz = inputs ["Lz" ],
413- )
414- print (f"Setup completed with mesh of { problem .fe .num_cells } elements." )
415- if inputs ["use_regular_grid" ]:
416- rho = inputs ["rho" ]
417- else :
418- rho = inputs ["rho" ][: inputs ["hex_mesh" ]["n_faces" ]]
419- # print(rho)
420-
421- print (rho .shape )
315+ problem , fwd_pred = setup (
316+ pts = inputs ["hex_mesh" ]["points" ][: inputs ["hex_mesh" ]["n_points" ]],
317+ cells = inputs ["hex_mesh" ]["faces" ][: inputs ["hex_mesh" ]["n_faces" ]],
318+ dirichlet_mask = inputs ["dirichlet_mask" ],
319+ dirichlet_values = inputs ["dirichlet_values" ],
320+ van_neumann_mask = inputs ["van_neumann_mask" ],
321+ van_neumann_values = inputs ["van_neumann_values" ],
322+ )
323+
324+ rho = inputs ["rho" ][: inputs ["hex_mesh" ]["n_points" ]]
325+
422326 sol_list = fwd_pred (rho )
423327 compliance = problem .compute_compliance (sol_list [0 ])
328+
424329 return {"compliance" : compliance .astype (jnp .float32 )}
425330
426331
0 commit comments