11
11
from __future__ import annotations
12
12
13
13
from collections .abc import Callable
14
-
15
14
from functools import partial
16
15
from typing import Union
17
16
26
25
ScipyConstraintDict = dict [
27
26
str , Union [str , Callable [[np .ndarray ], float ], Callable [[np .ndarray ], np .ndarray ]]
28
27
]
29
- NLC_TOL = - 1e-6
28
+ CONST_TOL = 1e-6
30
29
31
30
32
31
def make_scipy_bounds (
@@ -511,9 +510,12 @@ def f_grad(X):
511
510
512
511
513
512
def nonlinear_constraint_is_feasible (
514
- nonlinear_inequality_constraint : Callable , is_intrapoint : bool , x : Tensor
513
+ nonlinear_inequality_constraint : Callable ,
514
+ is_intrapoint : bool ,
515
+ x : Tensor ,
516
+ tolerance : float = CONST_TOL ,
515
517
) -> Tensor :
516
- """Checks if a nonlinear inequality constraint is fulfilled.
518
+ """Checks if a nonlinear inequality constraint is fulfilled (within tolerance) .
517
519
518
520
Args:
519
521
nonlinear_inequality_constraint: Callable to evaluate the
@@ -523,14 +525,17 @@ def nonlinear_constraint_is_feasible(
523
525
constraint has to evaluated over the whole q-batch and is a an
524
526
inter-point constraint.
525
527
x: Tensor of shape (batch x q x d).
528
+ tolerance: Rather than using the exact `const(x) >= 0` constraint, this helper
529
+ checks feasibility of `const(x) >= -tolerance`. This avoids marking the
530
+ candidates as infeasible due to tiny violations.
526
531
527
532
Returns:
528
533
A boolean tensor of shape (batch) indicating if the constraint is
529
534
satified by the corresponding batch of `x`.
530
535
"""
531
536
532
537
def check_x (x : Tensor ) -> bool :
533
- return _arrayify (nonlinear_inequality_constraint (x )).item () >= NLC_TOL
538
+ return _arrayify (nonlinear_inequality_constraint (x )).item () >= - tolerance
534
539
535
540
x_flat = x .view (- 1 , * x .shape [- 2 :])
536
541
is_feasible = torch .ones (x_flat .shape [0 ], dtype = torch .bool , device = x .device )
@@ -603,3 +608,82 @@ def make_scipy_nonlinear_inequality_constraints(
603
608
shapeX = shapeX ,
604
609
)
605
610
return scipy_nonlinear_inequality_constraints
611
+
612
+
613
+ def evaluate_feasibility (
614
+ X : Tensor ,
615
+ inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
616
+ equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
617
+ nonlinear_inequality_constraints : list [tuple [Callable , bool ]] | None = None ,
618
+ tolerance : float = CONST_TOL ,
619
+ ) -> Tensor :
620
+ r"""Evaluate feasibility of candidate points (within a tolerance).
621
+
622
+ Args:
623
+ X: The candidate tensor of shape `batch x q x d`.
624
+ inequality_constraints: A list of tuples (indices, coefficients, rhs),
625
+ with each tuple encoding an inequality constraint of the form
626
+ `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
627
+ `coefficients` should be torch tensors. See the docstring of
628
+ `make_scipy_linear_constraints` for an example. When q=1, or when
629
+ applying the same constraint to each candidate in the batch
630
+ (intra-point constraint), `indices` should be a 1-d tensor.
631
+ For inter-point constraints, in which the constraint is applied to the
632
+ whole batch of candidates, `indices` must be a 2-d tensor, where
633
+ in each row `indices[i] =(k_i, l_i)` the first index `k_i` corresponds
634
+ to the `k_i`-th element of the `q`-batch and the second index `l_i`
635
+ corresponds to the `l_i`-th feature of that element.
636
+ equality_constraints: A list of tuples (indices, coefficients, rhs),
637
+ with each tuple encoding an equality constraint of the form
638
+ `\sum_i (X[indices[i]] * coefficients[i]) = rhs`. See the docstring of
639
+ `make_scipy_linear_constraints` for an example.
640
+ nonlinear_inequality_constraints: A list of tuples representing the nonlinear
641
+ inequality constraints. The first element in the tuple is a callable
642
+ representing a constraint of the form `callable(x) >= 0`. In case of an
643
+ intra-point constraint, `callable()`takes in an one-dimensional tensor of
644
+ shape `d` and returns a scalar. In case of an inter-point constraint,
645
+ `callable()` takes a two dimensional tensor of shape `q x d` and again
646
+ returns a scalar. The second element is a boolean, indicating if it is an
647
+ intra-point or inter-point constraint (`True` for intra-point. `False` for
648
+ inter-point). For more information on intra-point vs inter-point
649
+ constraints, see the docstring of the `inequality_constraints` argument.
650
+ tolerance: The tolerance used to check the feasibility of equality constraints
651
+ and non-linear inequality constraints. For equality constraints, we check
652
+ if `abs(const(X) - rhs) < tolerance`. For non-linear inequality constraints,
653
+ we check if `const(X) >= -tolerance`. This avoids marking the candidates as
654
+ infeasible due to tiny violations.
655
+
656
+ Returns:
657
+ A boolean tensor of shape `batch` indicating if the corresponding candidate of
658
+ shape `q x d` is feasible.
659
+ """
660
+ is_feasible = torch .ones (X .shape [:- 2 ], device = X .device , dtype = torch .bool )
661
+ if inequality_constraints is not None :
662
+ for idx , coef , rhs in inequality_constraints :
663
+ if idx .ndim == 1 :
664
+ # Intra-point constraints.
665
+ is_feasible &= ((X [..., idx ] * coef ).sum (dim = - 1 ) >= rhs ).all (dim = - 1 )
666
+ else :
667
+ # Inter-point constraints.
668
+ is_feasible &= (X [..., idx [:, 0 ], idx [:, 1 ]] * coef ).sum (dim = - 1 ) >= rhs
669
+ if equality_constraints is not None :
670
+ for idx , coef , rhs in equality_constraints :
671
+ if idx .ndim == 1 :
672
+ # Intra-point constraints.
673
+ is_feasible &= (
674
+ ((X [..., idx ] * coef ).sum (dim = - 1 ) - rhs ).abs () < tolerance
675
+ ).all (dim = - 1 )
676
+ else :
677
+ # Inter-point constraints.
678
+ is_feasible &= (
679
+ (X [..., idx [:, 0 ], idx [:, 1 ]] * coef ).sum (dim = - 1 ) - rhs
680
+ ).abs () < tolerance
681
+ if nonlinear_inequality_constraints is not None :
682
+ for const , intra in nonlinear_inequality_constraints :
683
+ is_feasible &= nonlinear_constraint_is_feasible (
684
+ nonlinear_inequality_constraint = const ,
685
+ is_intrapoint = intra ,
686
+ x = X ,
687
+ tolerance = tolerance ,
688
+ )
689
+ return is_feasible
0 commit comments