@@ -117,7 +117,6 @@ def stress(u_grad, theta):
117117 lmbda = E * nu / ((1 + nu ) * (1 - 2 * nu ))
118118
119119 epsilon = 0.5 * (u_grad + u_grad .T )
120- sigma = lmbda * jnp .trace (epsilon ) * jnp .eye (self .dim ) + 2 * mu * epsilon
121120
122121 sigma = lmbda * jnp .trace (epsilon ) * jnp .eye (self .dim ) + 2.0 * mu * epsilon
123122 return sigma
@@ -141,6 +140,7 @@ def set_params(self, params: jnp.ndarray) -> None:
141140 # Override base class method.
142141 full_params = jnp .ones ((self .fe .num_cells , params .shape [1 ]))
143142 full_params = full_params .at [self .fe .flex_inds ].set (params )
143+ print (self .fe .num_quads )
144144 thetas = jnp .repeat (full_params [:, None , :], self .fe .num_quads , axis = 1 )
145145 self .full_params = full_params
146146 self .internal_vars = [thetas ]
@@ -209,6 +209,7 @@ def bc_factory(
209209 # Create a factory that captures the current value of i
210210 def make_location_fn (idx ):
211211 def location_fn (point , index ):
212+ # jax.debug.print("Mask at point {}: {}", point, jax.lax.dynamic_index_in_dim(masks, index, 0, keepdims=False))
212213 return (
213214 jax .lax .dynamic_index_in_dim (masks , index , 0 , keepdims = False )
214215 == idx
@@ -218,12 +219,18 @@ def location_fn(point, index):
218219
219220 def make_value_fn (idx ):
220221 def value_fn (point ):
222+ # jax.debug.print("Value {} at point {}", jax.lax.dynamic_index_in_dim(values, idx, 0, keepdims=False), point)
221223 return values [idx ]
222224
223225 return value_fn
224226
225227 def make_value_fn_vn (idx ):
226228 def value_fn_vn (u , x ):
229+ jax .debug .print (
230+ "Van Neumann Value {} at point {}" ,
231+ jax .lax .dynamic_index_in_dim (values , idx , 0 , keepdims = False ),
232+ x ,
233+ )
227234 return values [idx ]
228235
229236 return value_fn_vn
@@ -238,6 +245,9 @@ def value_fn_vn(u, x):
238245 dirichlet_values = jnp .array (dirichlet_values )
239246 van_neumann_values = jnp .array (van_neumann_values )
240247
248+ print (f"dirichlet_values: { dirichlet_values } " )
249+ print (f"van_neumann_values: { van_neumann_values } " )
250+
241251 dirichlet_location_fns , dirichlet_value_fns = bc_factory (
242252 dirichlet_mask , dirichlet_values
243253 )
@@ -281,40 +291,6 @@ def apply_fn(inputs: dict) -> dict:
281291 Returns:
282292 Dictionary containing the compliance of the structure.
283293 """
284- from typing import TypeVar
285-
286- T = TypeVar ("T" )
287-
288- def stop_grads_int (x : T ) -> T :
289- """Stops gradient computation.
290-
291- We cannot use jax.lax.stop_gradient directly because Tesseract meshes are
292- nested dictionaries with arrays and integers, and jax.lax.stop_gradient
293- does not support integers.
294-
295- Args:
296- x: Input value.
297-
298- Returns:
299- Value with stopped gradients.
300- """
301-
302- def stop (x ):
303- return jax ._src .ad_util .stop_gradient_p .bind (x )
304-
305- return jax .tree_util .tree_map (stop , x )
306-
307- # stop grads on all inputs except rho
308-
309- # problem, fwd_pred = setup(
310- # pts=stop_grads_int(inputs["hex_mesh"]["points"][: inputs["hex_mesh"]["n_points"]]),
311- # cells=stop_grads_int(inputs["hex_mesh"]["faces"][: inputs["hex_mesh"]["n_faces"]]),
312- # dirichlet_mask=stop_grads_int(inputs["dirichlet_mask"]),
313- # dirichlet_values=stop_grads_int(inputs["dirichlet_values"]),
314- # van_neumann_mask=stop_grads_int(inputs["van_neumann_mask"]),
315- # van_neumann_values=stop_grads_int(inputs["van_neumann_values"]),
316- # )
317-
318294 # no stop grads
319295 problem , fwd_pred = setup (
320296 pts = inputs ["hex_mesh" ]["points" ][: inputs ["hex_mesh" ]["n_points" ]],
0 commit comments