
Commit a59ca65

Rename est_type to estimator_type
This makes things clearer, as est is not a common abbreviation. This is as suggested in #419.

Reviewers: mtrofin
Reviewed By: mtrofin
Pull Request: #422
1 parent 3a4a297 commit a59ca65

6 files changed: +64 -60 lines changed


compiler_opt/es/blackbox_evaluator.py

Lines changed: 7 additions & 6 deletions
@@ -64,13 +64,13 @@ class SamplingBlackboxEvaluator(BlackboxEvaluator):
   """A blackbox evaluator that samples from a corpus to collect reward."""

   def __init__(self, train_corpus: corpus.Corpus,
-               est_type: blackbox_optimizers.EstimatorType,
+               estimator_type: blackbox_optimizers.EstimatorType,
                total_num_perturbations: int, num_ir_repeats_within_worker: int):
     self._samples = []
     self._train_corpus = train_corpus
     self._total_num_perturbations = total_num_perturbations
     self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
-    self._est_type = est_type
+    self._estimator_type = estimator_type

     super().__init__(train_corpus)

@@ -82,7 +82,8 @@ def get_results(
       sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
       self._samples.append(sample)
       # add copy of sample for antithetic perturbation pair
-      if self._est_type == (blackbox_optimizers.EstimatorType.ANTITHETIC):
+      if self._estimator_type == (
+          blackbox_optimizers.EstimatorType.ANTITHETIC):
         self._samples.append(sample)

     compile_args = zip(perturbations, self._samples)
@@ -111,10 +112,10 @@ class TraceBlackboxEvaluator(BlackboxEvaluator):
   """A blackbox evaluator that utilizes trace based cost modelling."""

   def __init__(self, train_corpus: corpus.Corpus,
-               est_type: blackbox_optimizers.EstimatorType, bb_trace_path: str,
-               function_index_path: str):
+               estimator_type: blackbox_optimizers.EstimatorType,
+               bb_trace_path: str, function_index_path: str):
     self._train_corpus = train_corpus
-    self._est_type = est_type
+    self._estimator_type = estimator_type
     self._bb_trace_path = bb_trace_path
     self._function_index_path = function_index_path
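Context for the get_results change above: under EstimatorType.ANTITHETIC the perturbation list holds flattened (p, -p) pairs, so each corpus sample is appended twice to keep zip(perturbations, self._samples) aligned. A minimal standalone sketch (illustrative values, not the repository's code):

import numpy as np

# Antithetic pairs are flattened as (p1, -p1, p2, -p2, ...).
p1, p2 = np.array([1.0, -2.0]), np.array([0.5, 3.0])
perturbations = [p1, -p1, p2, -p2]

samples = []
for sample in ['module_a', 'module_b']:  # hypothetical corpus samples
  samples.append(sample)
  samples.append(sample)  # copy for the antithetic partner

# Each (p, -p) pair is now evaluated on the same sample.
print(list(zip(perturbations, samples)))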

compiler_opt/es/blackbox_learner.py

Lines changed: 4 additions & 3 deletions
@@ -57,7 +57,7 @@ class BlackboxLearnerConfig:
   # What kind of ES training?
   # - antithetic: for each perturbation, try an antiperturbation
   # - forward_fd: try total_num_perturbations independent perturbations
-  est_type: blackbox_optimizers.EstimatorType
+  estimator_type: blackbox_optimizers.EstimatorType

   # Should the rewards for blackbox optimization in a single step be normalized?
   fvalues_normalization: bool
@@ -164,7 +164,7 @@ def __init__(self,
     self._summary_writer = tf.summary.create_file_writer(output_dir)

     self._evaluator = self._config.evaluator(self._train_corpus,
-                                             self._config.est_type)
+                                             self._config.estimator_type)

   def _get_perturbations(self) -> List[npt.NDArray[np.float32]]:
     """Get perturbations for the model weights."""
@@ -270,7 +270,8 @@ def run_step(self, pool: FixedWorkerPool) -> None:

     initial_perturbations = self._get_perturbations()
     # positive-negative pairs
-    if self._config.est_type == blackbox_optimizers.EstimatorType.ANTITHETIC:
+    if (self._config.estimator_type ==
+        blackbox_optimizers.EstimatorType.ANTITHETIC):
       initial_perturbations = [
           p for p in initial_perturbations for p in (p, -p)
       ]
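The nested comprehension above interleaves each perturbation with its negation. A quick illustration (not from the repository; the original reuses the name p inside the comprehension, which is legal Python if a little surprising to read):

import numpy as np

initial_perturbations = [np.array([1.0, 2.0]), np.array([3.0, -4.0])]
# [p1, p2] becomes [p1, -p1, p2, -p2].
expanded = [p for q in initial_perturbations for p in (q, -q)]
print(expanded)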

compiler_opt/es/blackbox_learner_test.py

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ def setUp(self):
     self._learner_config = blackbox_learner.BlackboxLearnerConfig(
         total_steps=1,
         blackbox_optimizer=blackbox_optimizers.Algorithm.MONTE_CARLO,
-        est_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
+        estimator_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
         fvalues_normalization=True,
         hyperparameters_update_method=blackbox_optimizers.UpdateMethod
         .NO_METHOD,
@@ -117,7 +117,7 @@ def _policy_saver_fn(parameters: npt.NDArray[np.float32],
     self._learner = blackbox_learner.BlackboxLearner(
         blackbox_opt=blackbox_optimizers.MonteCarloBlackboxOptimizer(
             precision_parameter=1.0,
-            est_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
+            estimator_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
             normalize_fvalues=True,
             hyperparameters_update_method=blackbox_optimizers.UpdateMethod
             .NO_METHOD,

compiler_opt/es/blackbox_optimizers.py

Lines changed: 34 additions & 34 deletions
@@ -124,7 +124,7 @@ class UpdateMethod(enum.Enum):

 def filter_top_directions(
     perturbations: FloatArray2D, function_values: FloatArray,
-    est_type: EstimatorType,
+    estimator_type: EstimatorType,
     num_top_directions: int) -> Tuple[FloatArray, FloatArray]:
   """Select the subset of top-performing perturbations.
@@ -134,7 +134,7 @@ def filter_top_directions(
       p, -p in the even/odd entries, so the directions p_1,...,p_n
       will be ordered (p_1, -p_1, p_2, -p_2,...)
     function_values: np array of reward values (maximization)
-    est_type: (forward_fd | antithetic)
+    estimator_type: (forward_fd | antithetic)
     num_top_directions: the number of top directions to include
       For antithetic, the total number of perturbations will
      be 2* this number, because we count p, -p as a single
@@ -148,16 +148,16 @@ def filter_top_directions(
   """
   if not num_top_directions > 0:
     return (perturbations, function_values)
-  if est_type == EstimatorType.FORWARD_FD:
+  if estimator_type == EstimatorType.FORWARD_FD:
     top_index = np.argsort(-function_values)
-  elif est_type == EstimatorType.ANTITHETIC:
+  elif estimator_type == EstimatorType.ANTITHETIC:
     top_index = np.argsort(-np.abs(function_values[0::2] -
                                    function_values[1::2]))
   top_index = top_index[:num_top_directions]
-  if est_type == EstimatorType.FORWARD_FD:
+  if estimator_type == EstimatorType.FORWARD_FD:
     perturbations = perturbations[top_index]
     function_values = function_values[top_index]
-  elif est_type == EstimatorType.ANTITHETIC:
+  elif estimator_type == EstimatorType.ANTITHETIC:
     perturbations = np.concatenate(
         (perturbations[2 * top_index], perturbations[2 * top_index + 1]),
         axis=0)
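To make the antithetic branch concrete: with the interleaved layout, pair i occupies rows 2i and 2i+1, and pairs are ranked by the gap |f(p) - f(-p)|. A small worked example (illustrative numbers only, not the repository's code):

import numpy as np

# f(p1), f(-p1), f(p2), f(-p2)
function_values = np.array([10.0, 8.0, 4.0, 9.0])
perturbations = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])

# Pair gaps: |10 - 8| = 2 and |4 - 9| = 5, so pair 1 ranks first.
top_index = np.argsort(-np.abs(function_values[0::2] - function_values[1::2]))
top_index = top_index[:1]  # keep the single best pair

kept = np.concatenate(
    (perturbations[2 * top_index], perturbations[2 * top_index + 1]), axis=0)
print(top_index, kept)  # [1] and the rows for p2, -p2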
@@ -245,11 +245,11 @@ class StatefulOptimizer(BlackboxOptimizer):
   Class contains common methods for handling the state.
   """

-  def __init__(self, est_type: EstimatorType, normalize_fvalues: bool,
+  def __init__(self, estimator_type: EstimatorType, normalize_fvalues: bool,
                hyperparameters_update_method: UpdateMethod,
                extra_params: Optional[Sequence[int]]):

-    self.est_type = est_type
+    self.estimator_type = estimator_type
     self.normalize_fvalues = normalize_fvalues
     self.hyperparameters_update_method = hyperparameters_update_method
     if hyperparameters_update_method == UpdateMethod.STATE_NORMALIZATION:
@@ -321,7 +321,7 @@ class MonteCarloBlackboxOptimizer(StatefulOptimizer):

   def __init__(self,
                precision_parameter: float,
-               est_type: EstimatorType,
+               estimator_type: EstimatorType,
                normalize_fvalues: bool,
                hyperparameters_update_method: UpdateMethod,
                extra_params: Optional[Sequence[int]],
@@ -342,8 +342,8 @@ def __init__(self,
     self.precision_parameter = precision_parameter
     self.num_top_directions = num_top_directions
     self.gradient_ascent_optimizer = gradient_ascent_optimizer
-    super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
-                     extra_params)
+    super().__init__(estimator_type, normalize_fvalues,
+                     hyperparameters_update_method, extra_params)

   # TODO: Issue #285
   def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
@@ -358,14 +358,14 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
       function_values = np.array(normalized_values[:-1])
       current_value = normalized_values[-1]
     top_ps, top_fs = filter_top_directions(perturbations, function_values,
-                                           self.est_type,
+                                           self.estimator_type,
                                            self.num_top_directions)
     gradient = np.zeros(dim)
     for i, perturbation in enumerate(top_ps):
       function_value = top_fs[i]
-      if self.est_type == EstimatorType.FORWARD_FD:
+      if self.estimator_type == EstimatorType.FORWARD_FD:
         gradient_sample = (function_value - current_value) * perturbation
-      elif self.est_type == EstimatorType.ANTITHETIC:
+      elif self.estimator_type == EstimatorType.ANTITHETIC:
         gradient_sample = function_value * perturbation
       gradient_sample /= self.precision_parameter**2
       gradient += gradient_sample
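The two branches in this loop compute different per-direction gradient samples: forward finite differences uses (f(x + sigma*p) - f(x)) * p / sigma**2, while the antithetic branch uses f * p / sigma**2, which over a (p, -p) pair accumulates to (f(x + sigma*p) - f(x - sigma*p)) * p / sigma**2. A quick numeric check (illustration, not the repository's code):

import numpy as np

sigma = 0.1
p = np.array([1.0, 0.0])
f_plus, f_minus = 3.0, 1.0  # rewards at x + sigma*p and x - sigma*p

# Antithetic: the pair contributes f_plus*p + f_minus*(-p).
pair_sum = (f_plus * p + f_minus * -p) / sigma**2
print(pair_sum)  # equals (f_plus - f_minus) * p / sigma**2 -> [200., 0.]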
@@ -374,7 +374,7 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
     # in that code, the denominator for antithetic was num_top_directions.
     # we maintain compatibility for now so that the same hyperparameters
     # currently used in Toaster will have the same effect
-    if self.est_type == EstimatorType.ANTITHETIC and \
+    if self.estimator_type == EstimatorType.ANTITHETIC and \
         len(top_ps) < len(perturbations):
       gradient *= 2
     # Use the gradient ascent optimizer to compute the next parameters with the
@@ -396,7 +396,7 @@ class SklearnRegressionBlackboxOptimizer(StatefulOptimizer):
   def __init__(self,
                regression_method: RegressionType,
                regularizer: float,
-               est_type: EstimatorType,
+               estimator_type: EstimatorType,
                normalize_fvalues: bool,
                hyperparameters_update_method: UpdateMethod,
                extra_params: Optional[Sequence[int]],
@@ -422,8 +422,8 @@ def __init__(self,
     else:
       raise ValueError('Optimization procedure option not available')
     self.gradient_ascent_optimizer = gradient_ascent_optimizer
-    super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
-                     extra_params)
+    super().__init__(estimator_type, normalize_fvalues,
+                     hyperparameters_update_method, extra_params)

   def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
                current_input: FloatArray, current_value: float) -> FloatArray:
@@ -439,11 +439,11 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,

     matrix = None
     b_vector = None
-    if self.est_type == EstimatorType.FORWARD_FD:
+    if self.estimator_type == EstimatorType.FORWARD_FD:
       matrix = np.array(perturbations)
       b_vector = (
           function_values - np.array([current_value] * len(function_values)))
-    elif self.est_type == EstimatorType.ANTITHETIC:
+    elif self.estimator_type == EstimatorType.ANTITHETIC:
       matrix = np.array(perturbations[::2])
       function_even_values = np.array(function_values.tolist()[::2])
       function_odd_values = np.array(function_values.tolist()[1::2])
@@ -495,20 +495,20 @@ def normalize_function_values(


 def monte_carlo_gradient(precision_parameter: float,
-                         est_type: EstimatorType,
+                         estimator_type: EstimatorType,
                          perturbations: FloatArray2D,
                          function_values: FloatArray,
                          current_value: float,
                          energy: Optional[float] = 0) -> FloatArray:
   """Calculates Monte Carlo gradient.

   There are several ways of estimating the gradient. This is specified by the
-  attribute self.est_type. Currently, forward finite difference (FFD) and
+  attribute self.estimator_type. Currently, forward finite difference (FFD) and
   antithetic are supported.

   Args:
     precision_parameter: sd of Gaussian perturbations
-    est_type: 'forward_fd' (FFD) or 'antithetic'
+    estimator_type: 'forward_fd' (FFD) or 'antithetic'
     perturbations: the simulated perturbations
     function_values: reward from perturbations (possibly normalized)
     current_value: estimated reward at current point (possibly normalized)
@@ -522,11 +522,11 @@ def monte_carlo_gradient(precision_parameter: float,
   """
   dim = len(perturbations[0])
   b_vector = None
-  if est_type == EstimatorType.FORWARD_FD:
+  if estimator_type == EstimatorType.FORWARD_FD:
     b_vector = (function_values -
                 np.array([current_value] * len(function_values))) / (
                     precision_parameter * precision_parameter)
-  elif est_type == EstimatorType.ANTITHETIC:
+  elif estimator_type == EstimatorType.ANTITHETIC:
     b_vector = function_values / (2.0 * precision_parameter *
                                   precision_parameter)
   else:
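Only the b_vector construction appears in this hunk; the rest of monte_carlo_gradient is unchanged by the rename. The two cases side by side (illustrative values, not the repository's code):

import numpy as np

sigma = 0.1          # precision_parameter, the sd of the Gaussian perturbations
current_value = 2.0
function_values = np.array([3.0, 1.5])

b_ffd = (function_values -
         np.array([current_value] * len(function_values))) / (sigma * sigma)
b_antithetic = function_values / (2.0 * sigma * sigma)
print(b_ffd, b_antithetic)  # [100. -50.] and [150. 75.]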
@@ -543,15 +543,15 @@ def monte_carlo_gradient(precision_parameter: float,
   return gradient


-def sklearn_regression_gradient(clf: LinearModel, est_type: EstimatorType,
+def sklearn_regression_gradient(clf: LinearModel, estimator_type: EstimatorType,
                                 perturbations: FloatArray2D,
                                 function_values: FloatArray,
                                 current_value: float) -> FloatArray:
   """Calculates gradient by function difference regression.

   Args:
     clf: an object (SkLearn linear model) which fits Ax = b
-    est_type: 'forward_fd' (FFD) or 'antithetic'
+    estimator_type: 'forward_fd' (FFD) or 'antithetic'
     perturbations: the simulated perturbations
     function_values: reward from perturbations (possibly normalized)
     current_value: estimated reward at current point (possibly normalized)
@@ -565,11 +565,11 @@ def sklearn_regression_gradient(clf: LinearModel, est_type: EstimatorType,
   matrix = None
   b_vector = None
   dim = perturbations[0].size
-  if est_type == EstimatorType.FORWARD_FD:
+  if estimator_type == EstimatorType.FORWARD_FD:
     matrix = np.array(perturbations)
     b_vector = (
         function_values - np.array([current_value] * len(function_values)))
-  elif est_type == EstimatorType.ANTITHETIC:
+  elif estimator_type == EstimatorType.ANTITHETIC:
     matrix = np.array(perturbations[::2])
     function_even_values = np.array(function_values.tolist()[::2])
     function_odd_values = np.array(function_values.tolist()[1::2])
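For the regression path, the hunk shows how the Ax = b system is set up: under forward_fd every perturbation is a row of A with b = f - f(x), while under antithetic only every other perturbation is kept and b is built from the even/odd value difference (the combining lines fall outside the shown hunk). A rough sketch of the idea with a ridge regressor (assumptions: fit_intercept=False and the /2 scaling are guesses for illustration; the fitted coefficients are read back as the gradient estimate; this is a paraphrase, not the repository's code):

import numpy as np
from sklearn.linear_model import Ridge

perturbations = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
function_values = np.array([3.0, 1.0, 2.5, 1.5])

matrix = np.array(perturbations[::2])  # one row per antithetic pair
# Assumed combination of even/odd values; the exact scaling is outside the hunk.
b_vector = (function_values[::2] - function_values[1::2]) / 2.0

clf = Ridge(alpha=0.01, fit_intercept=False)
clf.fit(matrix, b_vector)
print(clf.coef_)  # gradient estimate, roughly [1.0, 0.5] here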
@@ -903,14 +903,14 @@ class TrustRegionOptimizer(StatefulOptimizer):
   schedule that would have to be tuned.
   """

-  def __init__(self, precision_parameter: float, est_type: EstimatorType,
+  def __init__(self, precision_parameter: float, estimator_type: EstimatorType,
                normalize_fvalues: bool,
                hyperparameters_update_method: UpdateMethod,
                extra_params: Optional[Sequence[int]], tr_params: Mapping[str,
                                                                          Any]):
     self.precision_parameter = precision_parameter
-    super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
-                     extra_params)
+    super().__init__(estimator_type, normalize_fvalues,
+                     hyperparameters_update_method, extra_params)

     self.accepted_quadratic_model = None
     self.accepted_function_value = None
@@ -1147,12 +1147,12 @@ def update_quadratic_model(self, perturbations: FloatArray2D,
       current_value = normalized_values[1]
       self.normalized_current_value = current_value
     if self.params['grad_type'] == GradientType.REGRESSION:
-      new_gradient = sklearn_regression_gradient(self.clf, self.est_type,
+      new_gradient = sklearn_regression_gradient(self.clf, self.estimator_type,
                                                  perturbations, function_values,
                                                  current_value)
     else:
       new_gradient = monte_carlo_gradient(self.precision_parameter,
-                                          self.est_type, perturbations,
+                                          self.estimator_type, perturbations,
                                           function_values, current_value)
     new_gradient *= -1  # TR subproblem solver performs minimization
     if not is_update:

compiler_opt/es/blackbox_optimizers_test.py

Lines changed: 13 additions & 11 deletions
@@ -72,10 +72,10 @@ class BlackboxOptimizationAlgorithmsTest(parameterized.TestCase):
        blackbox_optimizers.EstimatorType.FORWARD_FD, 5,
        np.array([[4, 2], [8, -6], [-1, 5], [0, -3], [2, -1]
                 ]), np.array([10, 8, 4, 2, 1])))
-  def test_filtering(self, perturbations, function_values, est_type,
+  def test_filtering(self, perturbations, function_values, estimator_type,
                      num_top_directions, expected_ps, expected_fs):
     top_ps, top_fs = blackbox_optimizers.filter_top_directions(
-        perturbations, function_values, est_type, num_top_directions)
+        perturbations, function_values, estimator_type, num_top_directions)
     np.testing.assert_array_equal(expected_ps, top_ps)
     np.testing.assert_array_equal(expected_fs, top_fs)

@@ -88,13 +88,14 @@ def test_filtering(self, perturbations, function_values, est_type,
        blackbox_optimizers.EstimatorType.ANTITHETIC, 0, np.array([102, -34])),
       (perturbation_array, function_value_array,
        blackbox_optimizers.EstimatorType.FORWARD_FD, 0, np.array([74, -34])))
-  def test_monte_carlo_gradient(self, perturbations, function_values, est_type,
-                                num_top_directions, expected_gradient):
+  def test_monte_carlo_gradient(self, perturbations, function_values,
+                                estimator_type, num_top_directions,
+                                expected_gradient):
     precision_parameter = 0.1
     step_size = 0.01
     current_value = 2
     blackbox_object = blackbox_optimizers.MonteCarloBlackboxOptimizer(
-        precision_parameter, est_type, False,
+        precision_parameter, estimator_type, False,
         blackbox_optimizers.UpdateMethod.NO_METHOD, None, step_size,
         num_top_directions)
     current_input = np.zeros(2)
@@ -118,7 +119,7 @@ def test_monte_carlo_gradient(self, perturbations, function_values, est_type,
       (perturbation_array, function_value_array,
        blackbox_optimizers.EstimatorType.FORWARD_FD, 0, np.array([74, -34])))
   def test_monte_carlo_gradient_with_gradient_ascent_optimizer(
-      self, perturbations, function_values, est_type, num_top_directions,
+      self, perturbations, function_values, estimator_type, num_top_directions,
       expected_gradient):
     precision_parameter = 0.1
     step_size = 0.01
@@ -128,7 +129,7 @@ def test_monte_carlo_gradient_with_gradient_ascent_optimizer(
             step_size, 0.0))
     blackbox_object = (
         blackbox_optimizers.MonteCarloBlackboxOptimizer(
-            precision_parameter, est_type, False,
+            precision_parameter, estimator_type, False,
             blackbox_optimizers.UpdateMethod.NO_METHOD, None, None,
             num_top_directions, gradient_ascent_optimizer))
     current_input = np.zeros(2)
@@ -154,8 +155,9 @@ def test_monte_carlo_gradient_with_gradient_ascent_optimizer(
       (perturbation_array, function_value_array,
        blackbox_optimizers.EstimatorType.FORWARD_FD, 0,
        np.array([0.030203, 0.001796])))
-  def test_sklearn_gradient(self, perturbations, function_values, est_type,
-                            num_top_directions, expected_gradient):
+  def test_sklearn_gradient(self, perturbations, function_values,
+                            estimator_type, num_top_directions,
+                            expected_gradient):
     precision_parameter = 0.1
     step_size = 0.01
     current_value = 2
@@ -164,8 +166,8 @@ def test_sklearn_gradient(self, perturbations, function_values, est_type,
         gradient_ascent_optimization_algorithms.MomentumOptimizer(
             step_size, 0.0))
     blackbox_object = blackbox_optimizers.SklearnRegressionBlackboxOptimizer(
-        blackbox_optimizers.RegressionType.RIDGE, regularizer, est_type, True,
-        blackbox_optimizers.UpdateMethod.NO_METHOD, [], None,
+        blackbox_optimizers.RegressionType.RIDGE, regularizer, estimator_type,
+        True, blackbox_optimizers.UpdateMethod.NO_METHOD, [], None,
         gradient_ascent_optimizer)
     current_input = np.zeros(2)
     step = blackbox_object.run_step(perturbations, function_values,
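One practical note on a rename like this: positional callers are unaffected, but any call site passing the parameter by keyword must be updated, which is why the tests above change est_type= to estimator_type=. A hypothetical before/after (not part of the commit):

# Before the rename, a caller might have written:
#   opt = blackbox_optimizers.MonteCarloBlackboxOptimizer(
#       precision_parameter=1.0, est_type=EstimatorType.ANTITHETIC, ...)
# After it, the keyword must match the new name:
#   opt = blackbox_optimizers.MonteCarloBlackboxOptimizer(
#       precision_parameter=1.0, estimator_type=EstimatorType.ANTITHETIC, ...)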
