@@ -76,6 +76,11 @@ class OutputSchema(BaseModel):
7676 ] = Field (description = "Compliance of the structure, a measure of stiffness" )
7777
7878
79+ # displacement: Array[
80+ # (None, 3),
81+ # Float32,
82+ # ] = Field(description="Nodal displacement field")
83+
7984#
8085# Helper functions
8186#
@@ -140,7 +145,6 @@ def set_params(self, params: jnp.ndarray) -> None:
140145 # Override base class method.
141146 full_params = jnp .ones ((self .fe .num_cells , params .shape [1 ]))
142147 full_params = full_params .at [self .fe .flex_inds ].set (params )
143- print (self .fe .num_quads )
144148 thetas = jnp .repeat (full_params [:, None , :], self .fe .num_quads , axis = 1 )
145149 self .full_params = full_params
146150 self .internal_vars = [thetas ]
@@ -209,28 +213,25 @@ def bc_factory(
209213 # Create a factory that captures the current value of i
210214 def make_location_fn (idx ):
211215 def location_fn (point , index ):
212- # jax.debug.print("Mask at point {}: {}", point, jax.lax.dynamic_index_in_dim(masks, index, 0, keepdims=False))
213216 return (
214- jax .lax .dynamic_index_in_dim (masks , index , 0 , keepdims = False )
215- == idx
216- )
217+ jnp .sum (
218+ jax .lax .dynamic_index_in_dim (
219+ masks , index , 0 , keepdims = False
220+ )
221+ )
222+ == idx + 1
223+ ).astype (jnp .bool_ )
217224
218225 return location_fn
219226
220227 def make_value_fn (idx ):
221228 def value_fn (point ):
222- # jax.debug.print("Value {} at point {}", jax.lax.dynamic_index_in_dim(values, idx, 0, keepdims=False), point)
223229 return values [idx ]
224230
225231 return value_fn
226232
227233 def make_value_fn_vn (idx ):
228234 def value_fn_vn (u , x ):
229- jax .debug .print (
230- "Van Neumann Value {} at point {}" ,
231- jax .lax .dynamic_index_in_dim (values , idx , 0 , keepdims = False ),
232- x ,
233- )
234235 return values [idx ]
235236
236237 return value_fn_vn
@@ -242,11 +243,8 @@ def value_fn_vn(u, x):
242243
243244 return location_functions , value_functions
244245
245- dirichlet_values = jnp .array (dirichlet_values )
246- van_neumann_values = jnp .array (van_neumann_values )
247-
248- print (f"dirichlet_values: { dirichlet_values } " )
249- print (f"van_neumann_values: { van_neumann_values } " )
246+ dirichlet_mask = jnp .array (dirichlet_mask )
247+ van_neumann_mask = jnp .array (van_neumann_mask )
250248
251249 dirichlet_location_fns , dirichlet_value_fns = bc_factory (
252250 dirichlet_mask , dirichlet_values
0 commit comments