@@ -106,30 +106,13 @@ def get_tensor_map(self) -> Callable:
106106 """
107107
108108 def stress (u_grad , theta ):
109- Emax = 70.0e3
110- Emin = 1e-3 * Emax
109+ E = 70e3
111110 nu = 0.3
112- penal = 3.0
113- E = Emin + (Emax - Emin ) * theta [0 ] ** penal
114- epsilon = 0.5 * (u_grad + u_grad .T )
115- # eps11 = epsilon[0, 0]
116- # eps22 = epsilon[1, 1]
117- # eps12 = epsilon[0, 1]
118- # mu = E / (2 * (1 + nu))
119- # sigma = jnp.trace(epsilon) * jnp.eye(self.dim) + 2*mu*epsilon
120- # # sig11 = E / (1 + nu) / (1 - nu) * (eps11 + nu * eps22)
121- # # sig22 = E / (1 + nu) / (1 - nu) * (nu * eps11 + eps22)
122- # # sig12 = E / (1 + nu) * eps12
123- # # sigma = jnp.array([[sig11, sig12], [sig12, sig22]])
124-
125- # Correct 3D linear elasticity constitutive law
126- # Lamé parameters
127- lmbda = E * nu / ((1.0 + nu ) * (1.0 - 2.0 * nu )) # First Lamé parameter
128- mu = E / (2.0 * (1.0 + nu )) # Second Lamé parameter (shear modulus)
129-
130- # Stress-strain relationship
131- sigma = lmbda * jnp .trace (epsilon ) * jnp .eye (self .dim ) + 2.0 * mu * epsilon
111+ mu = E / (2.0 * (1.0 + nu ))
112+ lmbda = E * nu / ((1 + nu ) * (1 - 2 * nu ))
132113
114+ epsilon = 0.5 * (u_grad + u_grad .T )
115+ sigma = lmbda * jnp .trace (epsilon ) * jnp .eye (self .dim ) + 2 * mu * epsilon
133116 return sigma
134117
135118 return stress
@@ -210,76 +193,61 @@ def setup(
210193 problem instance and fwd_pred is the differentiable forward solver.
211194 """
212195 ele_type = "HEX8"
213-
214196 meshio_mesh = meshio .Mesh (points = pts , cells = {"hexahedron" : cells })
215197 mesh = Mesh (pts , meshio_mesh .cells_dict ["hexahedron" ])
216198
217- print (f"pts min: { jnp .min (pts , axis = 0 )} , pts max: { jnp .max (pts , axis = 0 )} " )
218-
219- # # Define boundary conditions and values.
220- # def fixed_location(point):
221- # return jnp.isclose(point[0], 0, atol=1e-5)
222-
223- # print(Lx, Ly, Lz)
224-
225- # def fixed_location(point):
226- # # return jnp.isclose(point[0], -Lx / 3, atol=0.1)
227- # return point[0] < (-Lx / 2 + 1e-5) # Left face
199+ def bc_factory (
200+ masks : jnp .ndarray ,
201+ values : jnp .ndarray ,
202+ is_van_neumann : bool = False ,
203+ ) -> tuple [list [Callable ], list [Callable ]]:
204+ location_functions = []
205+ value_functions = []
228206
229- # def load_location(point):
207+ for i in range (values .shape [0 ]):
208+ # Create a factory that captures the current value of i
209+ def make_location_fn (idx ):
210+ def location_fn (point , index ):
211+ return (
212+ jax .lax .dynamic_index_in_dim (masks , index , 0 , keepdims = False )
213+ == idx
214+ )
230215
231- # # return jnp.logical_and(
232- # # jnp.logical_and(
233- # # jnp.isclose(point[0], Lx / 2, atol=1e-5),
234- # # jnp.isclose(point[1], -Ly / 2, atol=1e-5),
235- # # ),
236- # # jnp.isclose(point[2], Lz / 2, atol=1e-5),
237- # # )
216+ return location_fn
238217
239- # return jnp.logical_and(
240- # jnp.isclose(point[0], 0, atol=1e-5),
241- # jnp.isclose(point[1], 0, atol=0.1 * Ly + 1e-5),
242- # )
218+ def make_value_fn (idx ):
219+ def value_fn (point ):
220+ return values [idx ]
243221
244- # def dirichlet_val(point):
245- # return 0.0
222+ return value_fn
246223
247- # # # Define boundary conditions and values.
248- # def fixed_location(point, index ):
249- # return jnp.isclose(point[0], -Lx/2, atol=0.1)
224+ def make_value_fn_vn ( idx ):
225+ def value_fn_vn ( u , x ):
226+ return values [ idx ]
250227
251- # def load_location(point):
252- # return jnp.logical_and(jnp.logical_and(
253- # jnp.isclose(point[0], Lx/2, atol=1e-2),
254- # jnp.isclose(point[2], -Lz/2, atol=1e-2),
255- # ), jnp.isclose(point[1], Ly/2, atol=1e-2))
228+ return value_fn_vn
256229
257- # def dirichlet_val(point):
258- # return 0.0
230+ location_functions .append (make_location_fn (i ))
231+ value_functions .append (
232+ make_value_fn_vn (i ) if is_van_neumann else make_value_fn (i )
233+ )
259234
260- # dirichlet_bc_info = [[fixed_location] * 3, [0, 1, 2], [dirichlet_val] * 3]
235+ return location_functions , value_functions
261236
262- # location_fns = [load_location]
237+ dirichlet_values = jnp .array (dirichlet_values )
238+ van_neumann_values = jnp .array (van_neumann_values )
263239
264- Lx = jnp .max (pts [:, 0 ]) - jnp .min (pts [:, 0 ])
265- Ly = jnp .max (pts [:, 1 ]) - jnp .min (pts [:, 1 ])
266- # Lz = jnp.max(pts[:, 2]) - jnp.min(pts[:, 2])
267-
268- def fixed_location (point ):
269- return jnp .isclose (point [0 ], 0.0 , atol = 1e-5 )
270-
271- def load_location (point ):
272- return jnp .logical_and (
273- jnp .isclose (point [0 ], Lx , atol = 1e-5 ),
274- jnp .isclose (point [1 ], 0.0 , atol = 0.1 * Ly + 1e-5 ),
275- )
240+ dirichlet_location_fns , dirichlet_value_fns = bc_factory (
241+ dirichlet_mask , dirichlet_values
242+ )
276243
277- def dirichlet_val (point ):
278- return 0.0
244+ van_neumann_locations , van_neumann_value_fns = bc_factory (
245+ van_neumann_mask , van_neumann_values , is_van_neumann = True
246+ )
279247
280- dirichlet_bc_info = [[ fixed_location ] * 3 , [0 , 1 , 2 ], [ dirichlet_val ] * 3 ]
248+ dirichlet_bc_info = [dirichlet_location_fns * 3 , [0 , 1 , 2 ], dirichlet_value_fns * 3 ]
281249
282- location_fns = [ load_location ]
250+ location_fns = van_neumann_locations
283251
284252 # Define forward problem
285253 problem = Elasticity (
@@ -289,8 +257,8 @@ def dirichlet_val(point):
289257 ele_type = ele_type ,
290258 dirichlet_bc_info = dirichlet_bc_info ,
291259 location_fns = location_fns ,
292- # additional_info=(van_neumann_value_fns,),
293- additional_info = ([0.1 ],),
260+ additional_info = (van_neumann_value_fns ,),
261+ # additional_info=([0.1],),
294262 )
295263
296264 # Apply the automatic differentiation wrapper
0 commit comments