Skip to content

Commit 90aa624

Browse files
committed
found error in bc setup
1 parent c101171 commit 90aa624

File tree

2 files changed

+60
-150
lines changed

2 files changed

+60
-150
lines changed

examples/ansys/fem_tess/tesseract_api.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ class OutputSchema(BaseModel):
7676
] = Field(description="Compliance of the structure, a measure of stiffness")
7777

7878

79+
# displacement: Array[
80+
# (None, 3),
81+
# Float32,
82+
# ] = Field(description="Nodal displacement field")
83+
7984
#
8085
# Helper functions
8186
#
@@ -140,7 +145,6 @@ def set_params(self, params: jnp.ndarray) -> None:
140145
# Override base class method.
141146
full_params = jnp.ones((self.fe.num_cells, params.shape[1]))
142147
full_params = full_params.at[self.fe.flex_inds].set(params)
143-
print(self.fe.num_quads)
144148
thetas = jnp.repeat(full_params[:, None, :], self.fe.num_quads, axis=1)
145149
self.full_params = full_params
146150
self.internal_vars = [thetas]
@@ -209,28 +213,25 @@ def bc_factory(
209213
# Create a factory that captures the current value of i
210214
def make_location_fn(idx):
211215
def location_fn(point, index):
212-
# jax.debug.print("Mask at point {}: {}", point, jax.lax.dynamic_index_in_dim(masks, index, 0, keepdims=False))
213216
return (
214-
jax.lax.dynamic_index_in_dim(masks, index, 0, keepdims=False)
215-
== idx
216-
)
217+
jnp.sum(
218+
jax.lax.dynamic_index_in_dim(
219+
masks, index, 0, keepdims=False
220+
)
221+
)
222+
== idx + 1
223+
).astype(jnp.bool_)
217224

218225
return location_fn
219226

220227
def make_value_fn(idx):
221228
def value_fn(point):
222-
# jax.debug.print("Value {} at point {}", jax.lax.dynamic_index_in_dim(values, idx, 0, keepdims=False), point)
223229
return values[idx]
224230

225231
return value_fn
226232

227233
def make_value_fn_vn(idx):
228234
def value_fn_vn(u, x):
229-
jax.debug.print(
230-
"Van Neumann Value {} at point {}",
231-
jax.lax.dynamic_index_in_dim(values, idx, 0, keepdims=False),
232-
x,
233-
)
234235
return values[idx]
235236

236237
return value_fn_vn
@@ -242,11 +243,8 @@ def value_fn_vn(u, x):
242243

243244
return location_functions, value_functions
244245

245-
dirichlet_values = jnp.array(dirichlet_values)
246-
van_neumann_values = jnp.array(van_neumann_values)
247-
248-
print(f"dirichlet_values: {dirichlet_values}")
249-
print(f"van_neumann_values: {van_neumann_values}")
246+
dirichlet_mask = jnp.array(dirichlet_mask)
247+
van_neumann_mask = jnp.array(van_neumann_mask)
250248

251249
dirichlet_location_fns, dirichlet_value_fns = bc_factory(
252250
dirichlet_mask, dirichlet_values

examples/ansys/optim_bars.ipynb

Lines changed: 46 additions & 134 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)