File tree Expand file tree Collapse file tree 2 files changed +367
-133
lines changed Expand file tree Collapse file tree 2 files changed +367
-133
lines changed Load Diff Large diffs are not rendered by default.
Original file line number Diff line number Diff line change @@ -273,6 +273,12 @@ def calculate_next_x(
273273 if x_max is None :
274274 x_max = self .x_max
275275
276+ # verify indputs
277+ x = self .__preprocess_jnp_array (x )
278+ x_min = self .__preprocess_jnp_array (x_min )
279+ x_max = self .__preprocess_jnp_array (x_max )
280+ constraint_values = self .__preprocess_jnp_array (constraint_values )
281+ objective_gradient = self .__preprocess_jnp_array (objective_gradient )
276282 self .__check_input_sizes (
277283 x ,
278284 x_min ,
@@ -316,6 +322,15 @@ def calculate_next_x(
316322
317323 return xmma
318324
325+ def __preprocess_jnp_array (
326+ self ,
327+ x : jax .typing .ArrayLike ,
328+ ) -> np .array :
329+ np_x = np .array (x )
330+ if len (np_x .shape ) == 1 :
331+ np_x = np_x [:, None ]
332+ return np_x
333+
319334 def __check_input_sizes (
320335 self ,
321336 x : jax .typing .ArrayLike ,
You can’t perform that action at this time.
0 commit comments