Skip to content

Commit 52a58d9

Browse files
authored
Edit perturbation type annotation and add algorithm enum (#287)
* Edit perturbation type annotation and add algorithm enum * Edit type annotations for consistency, make RegressionType gin configurable
1 parent ec3f4bc commit 52a58d9

File tree

1 file changed

+63
-60
lines changed

1 file changed

+63
-60
lines changed

compiler_opt/es/blackbox_optimizers.py

Lines changed: 63 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
import abc
5656
import enum
57+
import gin
5758
import math
5859

5960
import numpy as np
@@ -64,7 +65,13 @@
6465

6566
from compiler_opt.es import gradient_ascent_optimization_algorithms
6667

67-
SequenceOfFloats = Union[Sequence[float], npt.NDArray[np.float32]]
68+
FloatArray = npt.NDArray[np.float32]
69+
70+
# should specifically be a 2d numpy array of floats
71+
# but numpy.typing does not allow for that indication
72+
FloatArray2D = Sequence[FloatArray]
73+
74+
SequenceOfFloats = Union[Sequence[float], FloatArray]
6875

6976
LinearModel = Union[linear_model.Ridge, linear_model.Lasso,
7077
linear_model.LinearRegression]
@@ -75,22 +82,33 @@ class CurrentPointEstimate(enum.Enum):
7582
AVERAGE = 2
7683

7784

85+
@gin.constants_from_enum(module='blackbox_optimizers')
86+
class Algorithm(enum.Enum):
87+
MONTE_CARLO = 1
88+
TRUST_REGION = 2
89+
SKLEARN_REGRESSION = 3
90+
91+
92+
@gin.constants_from_enum(module='blackbox_optimizers')
7893
class EstimatorType(enum.Enum):
7994
FORWARD_FD = 1
8095
ANTITHETIC = 2
8196

8297

98+
@gin.constants_from_enum(module='blackbox_optimizers')
8399
class GradientType(enum.Enum):
84100
MC = 1
85101
REGRESSION = 2
86102

87103

104+
@gin.constants_from_enum(module='blackbox_optimizers')
88105
class RegressionType(enum.Enum):
89106
LASSO = 1
90107
RIDGE = 2
91108
LINEAR = 3
92109

93110

111+
@gin.constants_from_enum(module='blackbox_optimizers')
94112
class UpdateMethod(enum.Enum):
95113
STATE_NORMALIZATION = 1
96114
NO_METHOD = 2
@@ -100,10 +118,9 @@ class UpdateMethod(enum.Enum):
100118

101119

102120
def filter_top_directions(
103-
perturbations: npt.NDArray[np.float32],
104-
function_values: npt.NDArray[np.float32], est_type: EstimatorType,
105-
num_top_directions: int
106-
) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
121+
perturbations: FloatArray2D, function_values: FloatArray,
122+
est_type: EstimatorType,
123+
num_top_directions: int) -> Tuple[FloatArray, FloatArray]:
107124
"""Select the subset of top-performing perturbations.
108125
109126
Args:
@@ -151,10 +168,8 @@ class BlackboxOptimizer(metaclass=abc.ABCMeta):
151168
"""
152169

153170
@abc.abstractmethod
154-
def run_step(self, perturbations: npt.NDArray[np.float32],
155-
function_values: npt.NDArray[np.float32],
156-
current_input: npt.NDArray[np.float32],
157-
current_value: float) -> npt.NDArray[np.float32]:
171+
def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
172+
current_input: FloatArray, current_value: float) -> FloatArray:
158173
"""Conducts a single step of blackbox optimization procedure.
159174
160175
Conducts a single step of blackbox optimization procedure, given values of
@@ -332,10 +347,9 @@ def __init__(self,
332347
super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
333348
extra_params)
334349

335-
def run_step(self, perturbations: npt.NDArray[np.float32],
336-
function_values: npt.NDArray[np.float32],
337-
current_input: npt.NDArray[np.float32],
338-
current_value: float) -> npt.NDArray[np.float32]:
350+
# TODO: Issue #285
351+
def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
352+
current_input: FloatArray, current_value: float) -> FloatArray:
339353
dim = len(current_input)
340354
if self.normalize_fvalues:
341355
values = function_values.tolist()
@@ -413,10 +427,8 @@ def __init__(self,
413427
super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
414428
extra_params)
415429

416-
def run_step(self, perturbations: npt.NDArray[np.float32],
417-
function_values: npt.NDArray[np.float32],
418-
current_input: npt.NDArray[np.float32],
419-
current_value: float) -> npt.NDArray[np.float32]:
430+
def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
431+
current_input: FloatArray, current_value: float) -> FloatArray:
420432
dim = len(current_input)
421433
if self.normalize_fvalues:
422434
values = function_values.tolist()
@@ -474,8 +486,8 @@ def set_state(self, state: SequenceOfFloats) -> None:
474486

475487

476488
def normalize_function_values(
477-
function_values: npt.NDArray[np.float32],
478-
current_value: float) -> Tuple[npt.NDArray[np.float32], List[float]]:
489+
function_values: FloatArray,
490+
current_value: float) -> Tuple[FloatArray, List[float]]:
479491
values = function_values.tolist()
480492
values.append(current_value)
481493
mean = sum(values) / float(len(values))
@@ -484,13 +496,12 @@ def normalize_function_values(
484496
return (np.array(normalized_values[:-1]), normalized_values[-1])
485497

486498

487-
def monte_carlo_gradient(
488-
precision_parameter: float,
489-
est_type: EstimatorType,
490-
perturbations: npt.NDArray[np.float32],
491-
function_values: npt.NDArray[np.float32],
492-
current_value: float,
493-
energy: Optional[float] = 0) -> npt.NDArray[np.float32]:
499+
def monte_carlo_gradient(precision_parameter: float,
500+
est_type: EstimatorType,
501+
perturbations: FloatArray2D,
502+
function_values: FloatArray,
503+
current_value: float,
504+
energy: Optional[float] = 0) -> FloatArray:
494505
"""Calculates Monte Carlo gradient.
495506
496507
There are several ways of estimating the gradient. This is specified by the
@@ -530,11 +541,10 @@ def monte_carlo_gradient(
530541
return gradient
531542

532543

533-
def sklearn_regression_gradient(
534-
clf: LinearModel, est_type: EstimatorType,
535-
perturbations: npt.NDArray[np.float32],
536-
function_values: npt.NDArray[np.float32],
537-
current_value: float) -> npt.NDArray[np.float32]:
544+
def sklearn_regression_gradient(clf: LinearModel, est_type: EstimatorType,
545+
perturbations: FloatArray2D,
546+
function_values: FloatArray,
547+
current_value: float) -> FloatArray:
538548
"""Calculates gradient by function difference regression.
539549
540550
Args:
@@ -603,8 +613,8 @@ class QuadraticModel(object):
603613
# pylint: disable=invalid-name
604614
# argument Av should be capitalized as such for mathematical convention
605615
def __init__(self,
606-
Av: Callable[[npt.NDArray[np.float32]], npt.NDArray[np.float32]],
607-
b: npt.NDArray[np.float32],
616+
Av: Callable[[FloatArray], FloatArray],
617+
b: FloatArray,
608618
c: Optional[float] = 0):
609619
"""Initialize quadratic function.
610620
@@ -619,7 +629,7 @@ def __init__(self,
619629
self.c = c
620630

621631
# pylint: enable=invalid-name
622-
def f(self, x: npt.NDArray[np.float32]) -> float:
632+
def f(self, x: FloatArray) -> float:
623633
"""Evaluate the quadratic function.
624634
625635
Args:
@@ -629,7 +639,7 @@ def f(self, x: npt.NDArray[np.float32]) -> float:
629639
"""
630640
return 0.5 * np.dot(x, self.quad_v(x)) + np.dot(x, self.b) + self.c
631641

632-
def grad(self, x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
642+
def grad(self, x: FloatArray) -> FloatArray:
633643
"""Evaluate the gradient of the quadratic, Ax + b.
634644
635645
Args:
@@ -653,9 +663,8 @@ class ProjectedGradientOptimizer(object):
653663
"""
654664

655665
def __init__(self, function_object: QuadraticModel,
656-
projection_operator: Callable[[npt.NDArray[np.float32]],
657-
npt.NDArray[np.float32]],
658-
pgd_params: Mapping[str, Any], x_init: npt.NDArray[np.float32]):
666+
projection_operator: Callable[[FloatArray], FloatArray],
667+
pgd_params: Mapping[str, Any], x_init: FloatArray):
659668
self.f = function_object
660669
self.proj = projection_operator
661670
if pgd_params is not None:
@@ -698,7 +707,7 @@ def run_step(self) -> None:
698707
self.x = x_next
699708
self.k += 1
700709

701-
def get_solution(self) -> npt.NDArray[np.float32]:
710+
def get_solution(self) -> FloatArray:
702711
return self.x
703712

704713
def get_x_diff_norm(self) -> float:
@@ -708,9 +717,7 @@ def get_iterations(self) -> int:
708717
return self.k
709718

710719

711-
def make_projector(
712-
radius: float
713-
) -> Callable[[npt.NDArray[np.float32]], npt.NDArray[np.float32]]:
720+
def make_projector(radius: float) -> Callable[[FloatArray], FloatArray]:
714721
"""Makes an L2 projector function centered at origin.
715722
716723
Args:
@@ -719,7 +726,7 @@ def make_projector(
719726
A function of one argument that projects onto L2 ball.
720727
"""
721728

722-
def projector(w: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
729+
def projector(w: FloatArray) -> FloatArray:
723730
w_norm = np.linalg.norm(w, 2)
724731
if w_norm > radius:
725732
return radius / w_norm * w
@@ -748,7 +755,7 @@ class TrustRegionSubproblemOptimizer(object):
748755
def __init__(self,
749756
model_function: QuadraticModel,
750757
trust_region_params: Dict[str, Any],
751-
x_init: Optional[npt.NDArray[np.float32]] = None):
758+
x_init: Optional[FloatArray] = None):
752759
self.mf = model_function
753760
self.params = trust_region_params
754761
self.center = x_init
@@ -783,7 +790,7 @@ def solve_trust_region_subproblem(self) -> None:
783790

784791
self.x = pgd_solver.get_solution()
785792

786-
def get_solution(self) -> npt.NDArray[np.float32]:
793+
def get_solution(self) -> FloatArray:
787794
return self.x
788795

789796

@@ -913,7 +920,7 @@ def __init__(self, precision_parameter: float, est_type: EstimatorType,
913920
self.clf = linear_model.Lasso(alpha=self.params['grad_reg_alpha'])
914921
self.is_returned_step = False
915922

916-
def trust_region_test(self, current_input: npt.NDArray[np.float32],
923+
def trust_region_test(self, current_input: FloatArray,
917924
current_value: float) -> bool:
918925
"""Test the next step to determine how to update the trust region.
919926
@@ -981,9 +988,9 @@ def trust_region_test(self, current_input: npt.NDArray[np.float32],
981988
print('Unchanged: ' + str(self.radius) + log_message)
982989
return True
983990

984-
def update_hessian_part(self, perturbations: npt.NDArray[np.float32],
985-
function_values: npt.NDArray[np.float32],
986-
current_value: float, is_update: bool) -> None:
991+
def update_hessian_part(self, perturbations: FloatArray2D,
992+
function_values: FloatArray, current_value: float,
993+
is_update: bool) -> None:
987994
"""Updates the internal state which stores Hessian information.
988995
989996
Recall that the Hessian is given by
@@ -1046,13 +1053,12 @@ def update_hessian_part(self, perturbations: npt.NDArray[np.float32],
10461053
self.saved_function_values = np.append(self.saved_function_values,
10471054
function_values)
10481055

1049-
def create_hessv_function(
1050-
self) -> Callable[[npt.NDArray[np.float32]], npt.NDArray[np.float32]]:
1056+
def create_hessv_function(self) -> Callable[[FloatArray], FloatArray]:
10511057
"""Returns a function of one argument that evaluates Hessian-vector product.
10521058
"""
10531059
if self.params['dense_hessian']:
10541060

1055-
def hessv_func(x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
1061+
def hessv_func(x: FloatArray) -> FloatArray:
10561062
"""Calculates Hessian-vector product from dense Hessian.
10571063
10581064
Args:
@@ -1068,7 +1074,7 @@ def hessv_func(x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
10681074
return hessv
10691075
else:
10701076

1071-
def hessv_func(x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
1077+
def hessv_func(x: FloatArray) -> FloatArray:
10721078
"""Calculates Hessian-vector product from perturbation/value pairs.
10731079
10741080
Args:
@@ -1095,9 +1101,8 @@ def hessv_func(x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
10951101

10961102
return hessv_func
10971103

1098-
def update_quadratic_model(self, perturbations: npt.NDArray[np.float32],
1099-
function_values: npt.NDArray[np.float32],
1100-
current_value: float,
1104+
def update_quadratic_model(self, perturbations: FloatArray2D,
1105+
function_values: FloatArray, current_value: float,
11011106
is_update: bool) -> QuadraticModel:
11021107
"""Updates the internal state of the optimizer with new perturbations.
11031108
@@ -1145,10 +1150,8 @@ def update_quadratic_model(self, perturbations: npt.NDArray[np.float32],
11451150
is_update)
11461151
return QuadraticModel(self.create_hessv_function(), self.saved_gradient)
11471152

1148-
def run_step(self, perturbations: npt.NDArray[np.float32],
1149-
function_values: npt.NDArray[np.float32],
1150-
current_input: npt.NDArray[np.float32],
1151-
current_value: float) -> npt.NDArray[np.float32]:
1153+
def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
1154+
current_input: FloatArray, current_value: float) -> FloatArray:
11521155
"""Run a single step of trust region optimizer.
11531156
11541157
Args:

0 commit comments

Comments
 (0)