Skip to content

Commit a43bd4d

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Update nonlinear_constraint_is_feasible to return a boolean tensor (#2731)
Summary: Pull Request resolved: #2731 `nonlinear_constraint_is_feasible` checks whether the constraint is feasible for the given (batch of) candidates. Previously, this returned a boolean, which was False if any element of the batch was infeasible. This diff updates it to return a boolean tensor that shows whether each batch is feasible. Use cases are updated to comply with the new behavior. I'll utilize this in a follow up diff to introduce a helper that evaluates the feasibility of all forms of constraints. Reviewed By: dme65 Differential Revision: D69209007 fbshipit-source-id: 4f63ee6824734c3eb5c96bb5b1cf47a17d51b298
1 parent 8bda455 commit a43bd4d

File tree

3 files changed

+27
-20
lines changed

3 files changed

+27
-20
lines changed

botorch/generation/gen.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,21 @@ def f(x):
265265
# so it shouldn't be an issue given enough restarts.
266266
if nonlinear_inequality_constraints:
267267
for con, is_intrapoint in nonlinear_inequality_constraints:
268-
if not nonlinear_constraint_is_feasible(
269-
con, is_intrapoint=is_intrapoint, x=candidates
270-
):
271-
candidates = torch.from_numpy(x0).to(candidates).reshape(shapeX)
268+
if not (
269+
feasible := nonlinear_constraint_is_feasible(
270+
con, is_intrapoint=is_intrapoint, x=candidates
271+
)
272+
).all():
273+
# Replace the infeasible batches with feasible ICs.
274+
candidates[~feasible] = (
275+
torch.from_numpy(x0).to(candidates).reshape(shapeX)[~feasible]
276+
)
272277
warnings.warn(
273278
"SLSQP failed to converge to a solution the satisfies the "
274279
"non-linear constraints. Returning the feasible starting point.",
275280
OptimizationWarning,
276281
stacklevel=2,
277282
)
278-
break
279283

280284
clamped_candidates = columnwise_clamp(
281285
X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True

botorch/optim/parameter_constraints.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def f_grad(X):
512512

513513
def nonlinear_constraint_is_feasible(
514514
nonlinear_inequality_constraint: Callable, is_intrapoint: bool, x: Tensor
515-
) -> bool:
515+
) -> Tensor:
516516
"""Checks if a nonlinear inequality constraint is fulfilled.
517517
518518
Args:
@@ -522,23 +522,24 @@ def nonlinear_constraint_is_feasible(
522522
is applied pointwise and is broadcasted over the q-batch. Else, the
523523
constraint has to evaluated over the whole q-batch and is a an
524524
inter-point constraint.
525-
x: Tensor of shape (b x q x d).
525+
x: Tensor of shape (batch x q x d).
526526
527527
Returns:
528-
bool: True if the constraint is fulfilled, else False.
528+
A boolean tensor of shape (batch) indicating if the constraint is
529+
satified by the corresponding batch of `x`.
529530
"""
530531

531532
def check_x(x: Tensor) -> bool:
532533
return _arrayify(nonlinear_inequality_constraint(x)).item() >= NLC_TOL
533534

534-
for x_ in x:
535+
x_flat = x.view(-1, *x.shape[-2:])
536+
is_feasible = torch.ones(x_flat.shape[0], dtype=torch.bool, device=x.device)
537+
for i, x_ in enumerate(x_flat):
535538
if is_intrapoint:
536-
if not all(check_x(x__) for x__ in x_):
537-
return False
539+
is_feasible[i] &= all(check_x(x__) for x__ in x_)
538540
else:
539-
if not check_x(x_):
540-
return False
541-
return True
541+
is_feasible[i] &= check_x(x_)
542+
return is_feasible.view(x.shape[:-2])
542543

543544

544545
def make_scipy_nonlinear_inequality_constraints(
@@ -589,7 +590,7 @@ def make_scipy_nonlinear_inequality_constraints(
589590
nlc, is_intrapoint = constraint
590591
if not nonlinear_constraint_is_feasible(
591592
nlc, is_intrapoint=is_intrapoint, x=x0.reshape(shapeX)
592-
):
593+
).all():
593594
raise ValueError(
594595
"`batch_initial_conditions` must satisfy the non-linear inequality "
595596
"constraints."

test/optim/test_parameter_constraints.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,15 +358,16 @@ def nlc(x):
358358
),
359359
)
360360
)
361-
self.assertFalse(
361+
self.assertEqual(
362362
nonlinear_constraint_is_feasible(
363363
nlc,
364364
True,
365365
torch.tensor(
366366
[[[1.5, 1.5], [1.5, 1.5]], [[1.5, 1.5], [1.5, 3.5]]],
367367
device=self.device,
368368
),
369-
)
369+
).tolist(),
370+
[True, False],
370371
)
371372
self.assertTrue(
372373
nonlinear_constraint_is_feasible(
@@ -381,22 +382,23 @@ def nlc(x):
381382
[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]],
382383
device=self.device,
383384
),
384-
)
385+
).all()
385386
)
386387
self.assertFalse(
387388
nonlinear_constraint_is_feasible(
388389
nlc, False, torch.tensor([[[1.5, 1.5], [1.5, 1.5]]], device=self.device)
389390
)
390391
)
391-
self.assertFalse(
392+
self.assertEqual(
392393
nonlinear_constraint_is_feasible(
393394
nlc,
394395
False,
395396
torch.tensor(
396397
[[[1.0, 1.0], [1.0, 1.0]], [[1.5, 1.5], [1.5, 1.5]]],
397398
device=self.device,
398399
),
399-
)
400+
).tolist(),
401+
[True, False],
400402
)
401403

402404
def test_generate_unfixed_nonlin_constraints(self):

0 commit comments

Comments
 (0)