Skip to content

Commit 7c44a3e

Browse files
committed
fixed jnp -> mma -> jnp route
1 parent 60fcdec commit 7c44a3e

File tree

2 files changed

+367
-133
lines changed

2 files changed

+367
-133
lines changed

examples/ansys/pyvista_pymapdl.ipynb

Lines changed: 352 additions & 133 deletions
Large diffs are not rendered by default.

examples/ansys/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ def calculate_next_x(
273273
if x_max is None:
274274
x_max = self.x_max
275275

276+
# verify indputs
277+
x = self.__preprocess_jnp_array(x)
278+
x_min = self.__preprocess_jnp_array(x_min)
279+
x_max = self.__preprocess_jnp_array(x_max)
280+
constraint_values = self.__preprocess_jnp_array(constraint_values)
281+
objective_gradient = self.__preprocess_jnp_array(objective_gradient)
276282
self.__check_input_sizes(
277283
x,
278284
x_min,
@@ -316,6 +322,15 @@ def calculate_next_x(
316322

317323
return xmma
318324

325+
def __preprocess_jnp_array(
326+
self,
327+
x: jax.typing.ArrayLike,
328+
) -> np.array:
329+
np_x = np.array(x)
330+
if len(np_x.shape) == 1:
331+
np_x = np_x[:, None]
332+
return np_x
333+
319334
def __check_input_sizes(
320335
self,
321336
x: jax.typing.ArrayLike,

0 commit comments

Comments
 (0)