Skip to content

Commit 624b97e

Browse files
Merge branch 'ansys' of https://github.com/pasteurlabs/tesseract-jax into spaceclaim_tesseract
2 parents af687fd + e388269 commit 624b97e

File tree

4 files changed

+116
-182
lines changed

4 files changed

+116
-182
lines changed

examples/ansys/demo_2.ipynb

Lines changed: 62 additions & 93 deletions
Large diffs are not rendered by default.
498 KB
Binary file not shown.

examples/ansys/fem_tess/tesseract_api.py

Lines changed: 46 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -106,30 +106,13 @@ def get_tensor_map(self) -> Callable:
106106
"""
107107

108108
def stress(u_grad, theta):
109-
Emax = 70.0e3
110-
Emin = 1e-3 * Emax
109+
E = 70e3
111110
nu = 0.3
112-
penal = 3.0
113-
E = Emin + (Emax - Emin) * theta[0] ** penal
114-
epsilon = 0.5 * (u_grad + u_grad.T)
115-
# eps11 = epsilon[0, 0]
116-
# eps22 = epsilon[1, 1]
117-
# eps12 = epsilon[0, 1]
118-
# mu = E / (2 * (1 + nu))
119-
# sigma = jnp.trace(epsilon) * jnp.eye(self.dim) + 2*mu*epsilon
120-
# # sig11 = E / (1 + nu) / (1 - nu) * (eps11 + nu * eps22)
121-
# # sig22 = E / (1 + nu) / (1 - nu) * (nu * eps11 + eps22)
122-
# # sig12 = E / (1 + nu) * eps12
123-
# # sigma = jnp.array([[sig11, sig12], [sig12, sig22]])
124-
125-
# Correct 3D linear elasticity constitutive law
126-
# Lamé parameters
127-
lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu)) # First Lamé parameter
128-
mu = E / (2.0 * (1.0 + nu)) # Second Lamé parameter (shear modulus)
129-
130-
# Stress-strain relationship
131-
sigma = lmbda * jnp.trace(epsilon) * jnp.eye(self.dim) + 2.0 * mu * epsilon
111+
mu = E / (2.0 * (1.0 + nu))
112+
lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))
132113

114+
epsilon = 0.5 * (u_grad + u_grad.T)
115+
sigma = lmbda * jnp.trace(epsilon) * jnp.eye(self.dim) + 2 * mu * epsilon
133116
return sigma
134117

135118
return stress
@@ -210,76 +193,61 @@ def setup(
210193
problem instance and fwd_pred is the differentiable forward solver.
211194
"""
212195
ele_type = "HEX8"
213-
214196
meshio_mesh = meshio.Mesh(points=pts, cells={"hexahedron": cells})
215197
mesh = Mesh(pts, meshio_mesh.cells_dict["hexahedron"])
216198

217-
print(f"pts min: {jnp.min(pts, axis=0)}, pts max: {jnp.max(pts, axis=0)}")
218-
219-
# # Define boundary conditions and values.
220-
# def fixed_location(point):
221-
# return jnp.isclose(point[0], 0, atol=1e-5)
222-
223-
# print(Lx, Ly, Lz)
224-
225-
# def fixed_location(point):
226-
# # return jnp.isclose(point[0], -Lx / 3, atol=0.1)
227-
# return point[0] < (-Lx / 2 + 1e-5) # Left face
199+
def bc_factory(
200+
masks: jnp.ndarray,
201+
values: jnp.ndarray,
202+
is_van_neumann: bool = False,
203+
) -> tuple[list[Callable], list[Callable]]:
204+
location_functions = []
205+
value_functions = []
228206

229-
# def load_location(point):
207+
for i in range(values.shape[0]):
208+
# Create a factory that captures the current value of i
209+
def make_location_fn(idx):
210+
def location_fn(point, index):
211+
return (
212+
jax.lax.dynamic_index_in_dim(masks, index, 0, keepdims=False)
213+
== idx
214+
)
230215

231-
# # return jnp.logical_and(
232-
# # jnp.logical_and(
233-
# # jnp.isclose(point[0], Lx / 2, atol=1e-5),
234-
# # jnp.isclose(point[1], -Ly / 2, atol=1e-5),
235-
# # ),
236-
# # jnp.isclose(point[2], Lz / 2, atol=1e-5),
237-
# # )
216+
return location_fn
238217

239-
# return jnp.logical_and(
240-
# jnp.isclose(point[0], 0, atol=1e-5),
241-
# jnp.isclose(point[1], 0, atol=0.1 * Ly + 1e-5),
242-
# )
218+
def make_value_fn(idx):
219+
def value_fn(point):
220+
return values[idx]
243221

244-
# def dirichlet_val(point):
245-
# return 0.0
222+
return value_fn
246223

247-
# # # Define boundary conditions and values.
248-
# def fixed_location(point, index):
249-
# return jnp.isclose(point[0], -Lx/2, atol=0.1)
224+
def make_value_fn_vn(idx):
225+
def value_fn_vn(u, x):
226+
return values[idx]
250227

251-
# def load_location(point):
252-
# return jnp.logical_and(jnp.logical_and(
253-
# jnp.isclose(point[0], Lx/2, atol=1e-2),
254-
# jnp.isclose(point[2], -Lz/2, atol=1e-2),
255-
# ), jnp.isclose(point[1], Ly/2, atol=1e-2))
228+
return value_fn_vn
256229

257-
# def dirichlet_val(point):
258-
# return 0.0
230+
location_functions.append(make_location_fn(i))
231+
value_functions.append(
232+
make_value_fn_vn(i) if is_van_neumann else make_value_fn(i)
233+
)
259234

260-
# dirichlet_bc_info = [[fixed_location] * 3, [0, 1, 2], [dirichlet_val] * 3]
235+
return location_functions, value_functions
261236

262-
# location_fns = [load_location]
237+
dirichlet_values = jnp.array(dirichlet_values)
238+
van_neumann_values = jnp.array(van_neumann_values)
263239

264-
Lx = jnp.max(pts[:, 0]) - jnp.min(pts[:, 0])
265-
Ly = jnp.max(pts[:, 1]) - jnp.min(pts[:, 1])
266-
# Lz = jnp.max(pts[:, 2]) - jnp.min(pts[:, 2])
267-
268-
def fixed_location(point):
269-
return jnp.isclose(point[0], 0.0, atol=1e-5)
270-
271-
def load_location(point):
272-
return jnp.logical_and(
273-
jnp.isclose(point[0], Lx, atol=1e-5),
274-
jnp.isclose(point[1], 0.0, atol=0.1 * Ly + 1e-5),
275-
)
240+
dirichlet_location_fns, dirichlet_value_fns = bc_factory(
241+
dirichlet_mask, dirichlet_values
242+
)
276243

277-
def dirichlet_val(point):
278-
return 0.0
244+
van_neumann_locations, van_neumann_value_fns = bc_factory(
245+
van_neumann_mask, van_neumann_values, is_van_neumann=True
246+
)
279247

280-
dirichlet_bc_info = [[fixed_location] * 3, [0, 1, 2], [dirichlet_val] * 3]
248+
dirichlet_bc_info = [dirichlet_location_fns * 3, [0, 1, 2], dirichlet_value_fns * 3]
281249

282-
location_fns = [load_location]
250+
location_fns = van_neumann_locations
283251

284252
# Define forward problem
285253
problem = Elasticity(
@@ -289,8 +257,8 @@ def dirichlet_val(point):
289257
ele_type=ele_type,
290258
dirichlet_bc_info=dirichlet_bc_info,
291259
location_fns=location_fns,
292-
# additional_info=(van_neumann_value_fns,),
293-
additional_info=([0.1],),
260+
additional_info=(van_neumann_value_fns,),
261+
# additional_info=([0.1],),
294262
)
295263

296264
# Apply the automatic differentiation wrapper

examples/ansys/meshing_tess/tesseract_api.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,19 @@ def vectorized_subdivide_hex_mesh(
132132
133133
Hexahedron is constructed as follows:
134134
135-
3 -------- 2
135+
7 -------- 6
136136
/| /|
137-
7 -------- 6 |
137+
4 -------- 5 |
138138
| | | |
139-
| 0 -------|-1
139+
| 3 -------|-2
140140
|/ |/
141-
4 -------- 5
141+
0 -------- 1
142142
143143
Axis orientation:
144144
145-
y
146-
|
147-
|____ x
148-
/
149-
/
150-
z
145+
z y
146+
| /
147+
|/____ x
151148
152149
"""
153150
n_hex_new = mask.sum()
@@ -182,7 +179,7 @@ def vectorized_subdivide_hex_mesh(
182179
for corner in range(8):
183180
new_pts_coords = new_pts_coords.at[
184181
jnp.arange(n_hex_new) * 64 + cell * 8 + corner
185-
].set(center - offsets[:, corner])
182+
].set(center + offsets[:, corner])
186183

187184
new_hex_cells = new_hex_cells.at[
188185
jnp.arange(n_hex_new) * 8 + cell, corner

0 commit comments

Comments
 (0)