 It provides a modular framework for generating cutting planes.
 """
 
-from typing import Callable, Iterable
-from typing import TypeAlias, Literal, overload
+from typing import Callable, Iterable, Literal, overload
 
 import mip
 import numpy as np
 
-from .utils import standard_basis_vector
+from .utils import check_scalar
 
-QueryPoint: TypeAlias = dict[mip.Var, float]
-Var: TypeAlias = mip.Var | Iterable[mip.Var] | mip.LinExprTensor
-Input: TypeAlias = float | Iterable[float] | np.ndarray
-Func: TypeAlias = Callable[[Input], float]
-FuncGrad: TypeAlias = Callable[[Input], tuple[float, float | np.ndarray]]
-Grad: TypeAlias = Callable[[Input], float | np.ndarray]
+type QueryPoint = dict[mip.Var, float]
+type Var = mip.Var | Iterable[mip.Var] | mip.LinExprTensor
+type Input = float | Iterable[float] | np.ndarray
+type Func = Callable[[Input], float]
+type FuncGrad = Callable[[Input], tuple[float, float | np.ndarray]]
+type Grad = Callable[[Input], float | np.ndarray]
 
 
 class ConvexTerm:
@@ -25,15 +24,15 @@ class ConvexTerm:
     Attributes:
         var: The variable(s) included in the term. This can be provided in the form of a single variable, an
             iterable of multiple variables or a variable tensor.
-        func: A function for computing the term's value. This function should except one argument for each
+        func: A function for computing the term's value. This function should accept one argument for each
             variable in `var`. If `var` is a variable tensor, then the function should accept a single array.
-        grad: A function for computing the term's gradient. This function should except one argument for each
+        grad: A function for computing the term's gradient. This function should accept one argument for each
             variable in `var`. If `var` is a variable tensor, then the function should accept a single array. If
             `None`, then the gradient is approximated numerically using the central finite difference method. If
             `grad` is instead a Boolean and is `True`, then `func` is assumed to return a tuple where the first
             element is the function value and the second element is the gradient. This is useful when the gradient
             is expensive to compute.
-        step_size: The step size used for numerical gradient approximation. If `grad` is provided, then this argument
+        step_size: The step size used for numerical gradient approximation. Must be positive. If `grad` is provided, then this argument
             is ignored.
         name: The name for the term.
     """
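
As a quick illustration of the `grad=True` contract described in the attribute docs above, here is a hedged sketch (the function and the commented-out variables are hypothetical; `ConvexTerm` is the class defined in this module):

```python
import numpy as np

# Hypothetical example of grad=True: `func` returns (value, gradient) in a
# single call, which helps when the two share expensive computation.
def func_and_grad(a, b):
    value = a**2 + b**2
    return value, np.array([2 * a, 2 * b])

# term = ConvexTerm(var=(x, y), func=func_and_grad, grad=True)
```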
@@ -49,12 +48,19 @@ def __init__(
4948 """Convex term constructor.
5049
5150 Args:
52- var: The value of the `var` attribute .
53- func: The value of the `func` attribute .
54- grad: The value of the `grad` attribute .
55- step_size: The value of the `step_size` attribute . Must be positive.
56- name: The value of the `name` attribute .
51+ var: The variable(s) included in the term. Can be a single variable, an iterable of variables, or a variable tensor .
52+ func: The function for computing the term's value .
53+ grad: The function for computing the term's gradient, or None for numerical approximation, or True if func returns (value, grad) .
54+ step_size: The step size for numerical gradient approximation . Must be positive.
55+ name: The name for the term .
5756 """
57+ check_scalar (
58+ x = step_size ,
59+ name = "step_size" ,
60+ var_type = float ,
61+ lb = 0 ,
62+ include_boundaries = False ,
63+ )
5864 self .var = var
5965 self .func = func
6066 self .grad = grad
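
A minimal constructor sketch, assuming `ConvexTerm` is imported from this module and that the model, bounds, and quadratic below are illustrative stand-ins (with the new validation, a non-positive `step_size` would now raise):

```python
import mip
import numpy as np

model = mip.Model()
x = model.add_var(name="x", lb=-10, ub=10)
y = model.add_var(name="y", lb=-10, ub=10)

term = ConvexTerm(
    var=(x, y),
    func=lambda a, b: a**2 + b**2,               # one argument per variable
    grad=lambda a, b: np.array([2 * a, 2 * b]),  # explicit gradient
    step_size=1e-6,                              # ignored here, since grad is given
    name="quadratic",
)
```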
@@ -77,7 +83,7 @@ def __call__(self, query_point: QueryPoint, return_grad: bool = False) -> float
             return_grad: Whether to return the term's gradient.
 
         Returns:
-            If `return_grad=False`, then only the value of the term is returned. Conversely, if `return_grad=True`,
+            If `return_grad=False`, then only the value of the term is returned. If `return_grad=True`,
             then a tuple is returned where the first element is the term's value and the second element is the term's
             gradient.
         """
@@ -91,34 +97,48 @@ def __call__(self, query_point: QueryPoint, return_grad: bool = False) -> float
 
     @property
     def is_multivariable(self) -> bool:
-        """Check whether the term is multivariable."""
+        """Check whether the term is multivariable.
+
+        Returns:
+            True if the term involves multiple variables, False otherwise.
+        """
         return not isinstance(self.var, mip.Var)
 
     def generate_cut(self, query_point: QueryPoint) -> mip.LinExpr:
         """Generate a cutting plane for the term.
 
+        The cutting plane is a linear approximation of the convex term at the given query point,
+        valid for all feasible points due to convexity.
+
         Args:
-            query_point: dict mapping mip.Var to float
-                The query point for which the cutting plane is generated.
+            query_point: The query point for which the cutting plane is generated.
 
         Returns:
-            The linear constraint representing the cutting plane.
+            A linear expression representing the cutting plane constraint.
         """
-        func, grad = self(query_point=query_point, return_grad=True)
+        value, grad = self(query_point=query_point, return_grad=True)
         x = self._get_input(query_point=query_point)
         if self.is_multivariable:
-            return mip.xsum(grad * (np.array(self.var) - x)) + func
-        return grad * (self.var - x) + func
+            return mip.xsum(grad * (np.array(self.var) - x)) + value
+        return grad * (self.var - x) + value
 
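
For reference, the expression returned by `generate_cut` is the standard gradient cut: for a convex f and query point x̄,

```latex
f(x) \ge f(\bar{x}) + \nabla f(\bar{x})^{\top} (x - \bar{x}) \quad \text{for all } x,
```

so `grad * (self.var - x) + value` (and its `mip.xsum` analogue) underestimates the term everywhere and is tight at the query point.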
     def _get_input(self, query_point: QueryPoint) -> Input:
+        """Extract input values from query point based on variable type.
+
+        Args:
+            query_point: The query point containing variable values.
+
+        Returns:
+            Input values in the format expected by the function.
+        """
         if self.is_multivariable:
             return np.array([query_point[var] for var in self.var])
         return query_point[self.var]
 
     def _evaluate_func(self, x: Input) -> float | tuple[float, float | np.ndarray]:
         """Evaluate the function value.
 
-        If `grad=True`, then both the value of the function and it's gradient are returned.
+        If `grad=True`, then both the value of the function and its gradient are returned.
         """
         if isinstance(self.var, (mip.Var, mip.LinExprTensor)):
             return self.func(x)
@@ -127,7 +147,14 @@ def _evaluate_func(self, x: Input) -> float | tuple[float, float | np.ndarray]:
         raise TypeError(f"Input of type '{type(x)}' not supported.")
 
     def _evaluate_grad(self, x: Input) -> float | np.ndarray:
-        """Evaluate the gradient."""
+        """Evaluate the gradient.
+
+        Args:
+            x: The input values at which to evaluate the gradient.
+
+        Returns:
+            The gradient value(s).
+        """
         if not self.grad:
             return self._approximate_grad(x=x)
         if isinstance(self.var, (mip.Var, mip.LinExprTensor)):
@@ -137,15 +164,22 @@ def _evaluate_grad(self, x: Input) -> float | np.ndarray:
         raise TypeError(f"Input of type '{type(x)}' not supported.")
 
     def _approximate_grad(self, x: Input) -> float | np.ndarray:
-        """Approximate the gradient of the function at point using the central finite difference method."""
+        """Approximate the gradient using central finite differences.
+
+        Args:
+            x: The input values at which to approximate the gradient.
+
+        Returns:
+            The approximated gradient value(s).
+        """
         if self.is_multivariable:
             n_dim = len(x)
             grad = np.zeros(n_dim)
+            e = np.eye(n_dim)
             for i in range(n_dim):
-                e_i = standard_basis_vector(i=i, n_dim=n_dim)
                 grad[i] = (
-                    self._evaluate_func(x=x + self.step_size / 2 * e_i)
-                    - self._evaluate_func(x=x - self.step_size / 2 * e_i)
+                    self._evaluate_func(x=x + self.step_size / 2 * e[i])
+                    - self._evaluate_func(x=x - self.step_size / 2 * e[i])
                 ) / self.step_size
             return grad
         return (
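
Finally, a standalone sketch of the central finite-difference rule that `_approximate_grad` implements (plain NumPy, independent of this module; `h` plays the role of `step_size`, and the rows of `np.eye` serve as the standard basis vectors, matching the change in this diff):

```python
import numpy as np

def approx_grad(func, x, h=1e-6):
    """Central differences: grad[i] ~ (f(x + h/2 * e_i) - f(x - h/2 * e_i)) / h."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros(x.size)
    e = np.eye(x.size)  # rows are the standard basis vectors
    for i in range(x.size):
        grad[i] = (func(x + h / 2 * e[i]) - func(x - h / 2 * e[i])) / h
    return grad

print(approx_grad(lambda v: float((v**2).sum()), [1.0, -2.0]))  # ~ [ 2., -4.]
```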