Skip to content

Commit e7c88a4

Browse files
committed
debugging jax-fem solver
1 parent 93d6a4e commit e7c88a4

File tree

3 files changed

+285
-182
lines changed

3 files changed

+285
-182
lines changed

examples/ansys/demo_2.ipynb

Lines changed: 240 additions & 42 deletions
Large diffs are not rendered by default.
278 KB
Binary file not shown.

examples/ansys/fem_tess/tesseract_api.py

Lines changed: 45 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55
import jax
66
import jax.numpy as jnp
7-
import numpy as np
87
import meshio
9-
from jax_fem.generate_mesh import Mesh, box_mesh, get_meshio_cell_type
8+
from jax_fem.generate_mesh import Mesh
109

1110
# Import JAX-FEM specific modules
1211
from jax_fem.problem import Problem
@@ -64,31 +63,6 @@ class InputSchema(BaseModel):
6463
hex_mesh: HexMesh = Field(
6564
description="Hexahedral mesh representation of the geometry",
6665
)
67-
Lx: float = Field(
68-
default=60.0, description="Length of the simulation box in the x direction."
69-
)
70-
Ly: float = Field(
71-
default=30.0,
72-
description=("Length of the simulation box in the y direction."),
73-
)
74-
Lz: float = Field(
75-
default=30.0, description="Length of the simulation box in the z direction."
76-
)
77-
Nx: int = Field(
78-
default=60,
79-
description=("Number of elements in the x direction."),
80-
)
81-
Ny: int = Field(
82-
default=30,
83-
description=("Number of elements in the y direction."),
84-
)
85-
Nz: int = Field(
86-
default=30,
87-
description=("Number of elements in the z direction."),
88-
)
89-
use_regular_grid: bool = Field(
90-
description="Toggle to use a regular grid mesh instead of imported mesh",
91-
)
9266

9367

9468
class OutputSchema(BaseModel):
@@ -131,7 +105,7 @@ def get_tensor_map(self) -> Callable:
131105
Callable that computes stress from strain gradient and density.
132106
"""
133107

134-
def stress(u_grad, theta): # noqa: ANN001
108+
def stress(u_grad, theta):
135109
Emax = 70.0e3
136110
Emin = 1e-3 * Emax
137111
nu = 0.3
@@ -166,12 +140,13 @@ def get_surface_maps(self) -> list[Callable]:
166140
Returns:
167141
List of van Neumann boundary condition value functions.
168142
"""
169-
# def surface_map(u, x):
170-
# return jnp.array([0.0, 0.0, 100.0])
171143

172-
# return [surface_map]
144+
def surface_map(u, x):
145+
return jnp.array([0.0, 0.0, 100.0])
173146

174-
return self.van_neumann_value_fns
147+
return [surface_map]
148+
149+
# return self.van_neumann_value_fns
175150

176151
def set_params(self, params: jnp.ndarray) -> None:
177152
"""Set density parameters for topology optimization.
@@ -205,20 +180,14 @@ def compute_compliance(self, sol: jnp.ndarray) -> jnp.ndarray:
205180
u_face = jnp.sum(u_face, axis=2)
206181
subset_quad_points = self.physical_surface_quad_points[0]
207182
neumann_fn = self.get_surface_maps()[0]
208-
traction = jax.vmap(jax.vmap(neumann_fn))(u_face, subset_quad_points)
183+
traction = -jax.vmap(jax.vmap(neumann_fn))(u_face, subset_quad_points)
209184
val = jnp.sum(traction * u_face * nanson_scale[:, :, None])
210185
return val
211186

212187

213188
# Memoize the setup function to avoid expensive recomputation
214189
# @lru_cache(maxsize=1)
215190
def setup(
216-
Nx: int = 60,
217-
Ny: int = 30,
218-
Nz: int = 30,
219-
Lx: float = 60.0,
220-
Ly: float = 30.0,
221-
Lz: float = 30.0,
222191
pts: jnp.ndarray = None,
223192
cells: jnp.ndarray = None,
224193
dirichlet_mask: jnp.ndarray = None,
@@ -229,12 +198,6 @@ def setup(
229198
"""Setup the elasticity problem and its differentiable solver.
230199
231200
Args:
232-
Nx: Number of elements in x direction for regular grid.
233-
Ny: Number of elements in y direction for regular grid.
234-
Nz: Number of elements in z direction for regular grid.
235-
Lx: Length of the domain in x direction for regular grid.
236-
Ly: Length of the domain in y direction for regular grid.
237-
Lz: Length of the domain in z direction for regular grid.
238201
pts: Optional array of mesh vertex positions for custom mesh.
239202
cells: Optional array of hexahedral cell definitions for custom mesh.
240203
dirichlet_mask: Mask array for Dirichlet boundary conditions.
@@ -247,56 +210,11 @@ def setup(
247210
problem instance and fwd_pred is the differentiable forward solver.
248211
"""
249212
ele_type = "HEX8"
250-
if pts is None and cells is None:
251-
cell_type = get_meshio_cell_type(ele_type)
252-
meshio_mesh = box_mesh(
253-
Nx=Nx, Ny=Ny, Nz=Nz, domain_x=Lx, domain_y=Ly, domain_z=Lz
254-
)
255-
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])
256-
else:
257-
meshio_mesh = meshio.Mesh(points=pts, cells={"hexahedron": cells})
258-
mesh = Mesh(pts, meshio_mesh.cells_dict["hexahedron"])
259-
260-
def bc_factory(
261-
masks: jnp.ndarray,
262-
values: jnp.ndarray,
263-
is_van_neumann: bool = False,
264-
) -> tuple[list[Callable], list[Callable]]:
265-
location_functions = []
266-
value_functions = []
267-
for i in range(values.shape[0]):
268-
269-
def location_fn(point, index): # noqa: ANN001
270-
# return mask[index]
271-
return (
272-
jax.lax.dynamic_index_in_dim(masks, index, 0, keepdims=False) == i
273-
)
274-
275-
def value_fn(point): # noqa: ANN001
276-
return values[i]
277-
278-
def value_fn_vn(u, x): # noqa: ANN001
279-
return values[i]
280-
281-
location_functions.append(location_fn)
282-
value_functions.append(value_fn_vn if is_van_neumann else value_fn)
283-
284-
return location_functions, value_functions
285-
286-
dirichlet_values = np.array(dirichlet_values)
287-
van_neumann_values = np.array(van_neumann_values)
288-
289-
dirichlet_location_fns, dirichlet_value_fns = bc_factory(
290-
dirichlet_mask, dirichlet_values
291-
)
292213

293-
van_neumann_locations, van_neumann_value_fns = bc_factory(
294-
van_neumann_mask, van_neumann_values, is_van_neumann=True
295-
)
296-
297-
dirichlet_bc_info = [dirichlet_location_fns * 3, [0, 1, 2], dirichlet_value_fns * 3]
214+
meshio_mesh = meshio.Mesh(points=pts, cells={"hexahedron": cells})
215+
mesh = Mesh(pts, meshio_mesh.cells_dict["hexahedron"])
298216

299-
location_fns = van_neumann_locations
217+
print(f"pts min: {jnp.min(pts, axis=0)}, pts max: {jnp.max(pts, axis=0)}")
300218

301219
# # Define boundary conditions and values.
302220
# def fixed_location(point):
@@ -343,6 +261,26 @@ def value_fn_vn(u, x): # noqa: ANN001
343261

344262
# location_fns = [load_location]
345263

264+
Lx = jnp.max(pts[:, 0]) - jnp.min(pts[:, 0])
265+
Ly = jnp.max(pts[:, 1]) - jnp.min(pts[:, 1])
266+
# Lz = jnp.max(pts[:, 2]) - jnp.min(pts[:, 2])
267+
268+
def fixed_location(point):
269+
return jnp.isclose(point[0], 0.0, atol=1e-5)
270+
271+
def load_location(point):
272+
return jnp.logical_and(
273+
jnp.isclose(point[0], Lx, atol=1e-5),
274+
jnp.isclose(point[1], 0.0, atol=0.1 * Ly + 1e-5),
275+
)
276+
277+
def dirichlet_val(point):
278+
return 0.0
279+
280+
dirichlet_bc_info = [[fixed_location] * 3, [0, 1, 2], [dirichlet_val] * 3]
281+
282+
location_fns = [load_location]
283+
346284
# Define forward problem
347285
problem = Elasticity(
348286
mesh,
@@ -351,8 +289,8 @@ def value_fn_vn(u, x): # noqa: ANN001
351289
ele_type=ele_type,
352290
dirichlet_bc_info=dirichlet_bc_info,
353291
location_fns=location_fns,
354-
additional_info=(van_neumann_value_fns,),
355-
# additional_info=([0.1],)
292+
# additional_info=(van_neumann_value_fns,),
293+
additional_info=([0.1],),
356294
)
357295

358296
# Apply the automatic differentiation wrapper
@@ -374,53 +312,20 @@ def apply_fn(inputs: dict) -> dict:
374312
Returns:
375313
Dictionary containing the compliance of the structure.
376314
"""
377-
378-
if not inputs["use_regular_grid"]:
379-
problem, fwd_pred = setup(
380-
Nx=inputs["Nx"],
381-
Ny=inputs["Ny"],
382-
Nz=inputs["Nz"],
383-
Lx=inputs["Lx"],
384-
Ly=inputs["Ly"],
385-
Lz=inputs["Lz"],
386-
# pts=jax.lax.dynamic_slice_in_dim(
387-
# inputs["hex_mesh"]["points"],
388-
# 0,
389-
# inputs["hex_mesh"]["n_points"],
390-
# axis=0,
391-
# ),
392-
# cells=jax.lax.dynamic_slice_in_dim(
393-
# inputs["hex_mesh"]["faces"],
394-
# 0,
395-
# inputs["hex_mesh"]["n_faces"],
396-
# axis=0,
397-
# ),
398-
pts=inputs["hex_mesh"]["points"][: inputs["hex_mesh"]["n_points"]],
399-
cells=inputs["hex_mesh"]["faces"][: inputs["hex_mesh"]["n_faces"]],
400-
dirichlet_mask=inputs["dirichlet_mask"],
401-
dirichlet_values=inputs["dirichlet_values"],
402-
van_neumann_mask=inputs["van_neumann_mask"],
403-
van_neumann_values=inputs["van_neumann_values"],
404-
)
405-
else:
406-
problem, fwd_pred = setup(
407-
Nx=inputs["Nx"],
408-
Ny=inputs["Ny"],
409-
Nz=inputs["Nz"],
410-
Lx=inputs["Lx"],
411-
Ly=inputs["Ly"],
412-
Lz=inputs["Lz"],
413-
)
414-
print(f"Setup completed with mesh of {problem.fe.num_cells} elements.")
415-
if inputs["use_regular_grid"]:
416-
rho = inputs["rho"]
417-
else:
418-
rho = inputs["rho"][: inputs["hex_mesh"]["n_faces"]]
419-
# print(rho)
420-
421-
print(rho.shape)
315+
problem, fwd_pred = setup(
316+
pts=inputs["hex_mesh"]["points"][: inputs["hex_mesh"]["n_points"]],
317+
cells=inputs["hex_mesh"]["faces"][: inputs["hex_mesh"]["n_faces"]],
318+
dirichlet_mask=inputs["dirichlet_mask"],
319+
dirichlet_values=inputs["dirichlet_values"],
320+
van_neumann_mask=inputs["van_neumann_mask"],
321+
van_neumann_values=inputs["van_neumann_values"],
322+
)
323+
324+
rho = inputs["rho"][: inputs["hex_mesh"]["n_points"]]
325+
422326
sol_list = fwd_pred(rho)
423327
compliance = problem.compute_compliance(sol_list[0])
328+
424329
return {"compliance": compliance.astype(jnp.float32)}
425330

426331

0 commit comments

Comments
 (0)