
Commit e065a1d

Fix imports and remove extraneous comments (#258)
Fix imports and remove extraneous comments; Resolve pylint issues
1 parent 60e0c59 commit e065a1d

File tree: 5 files changed (+67 additions, -62 deletions)


compiler_opt/es/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

compiler_opt/es/blackbox_optimizers.py

Lines changed: 32 additions & 40 deletions
@@ -21,9 +21,6 @@
 # https://arxiv.org/abs/1804.02395
 #
 ###############################################################################
-"""
-first and second order blackbox optimizers
-"""
 r"""Library of blackbox optimization algorithms.
 
 Library of stateful blackbox optimization algorithms taking as input the values
@@ -35,18 +32,16 @@
 import math
 import numpy as np
 from sklearn import linear_model
+from absl import flags
+import scipy.optimize as sp_opt
 
-import gradient_ascent_optimization_algorithms
+from compiler_opt.es import gradient_ascent_optimization_algorithms
 
 
 def filter_top_directions(perturbations, function_values, est_type,
                           num_top_directions):
   """Select the subset of top-performing perturbations.
 
-  TODO(b/139662389): In the future, we may want (either here or inside the
-  perturbation generator) to add assertions that Antithetic perturbations are
-  delivered in the expected order (i.e (p_1, -p_1, p_2, -p_2,...)).
-
   Args:
     perturbations: np array of perturbations
       For antithetic, it is assumed that the input puts the pair of
@@ -66,16 +61,16 @@ def filter_top_directions(perturbations, function_values, est_type,
   """
   if not num_top_directions > 0:
     return (perturbations, function_values)
-  if est_type == "forward_fd":
+  if est_type == 'forward_fd':
     top_index = np.argsort(-function_values)
-  elif est_type == "antithetic":
+  elif est_type == 'antithetic':
     top_index = np.argsort(-np.abs(function_values[0::2] -
                                    function_values[1::2]))
   top_index = top_index[:num_top_directions]
-  if est_type == "forward_fd":
+  if est_type == 'forward_fd':
     perturbations = perturbations[top_index]
     function_values = function_values[top_index]
-  elif est_type == "antithetic":
+  elif est_type == 'antithetic':
     perturbations = np.concatenate(
         (perturbations[2 * top_index], perturbations[2 * top_index + 1]),
         axis=0)
@@ -109,7 +104,7 @@ def run_step(self, perturbations, function_values, current_input,
       New input obtained by conducting a single step of the blackbox
       optimization procedure.
     """
-    raise NotImplementedError("Abstract method")
+    raise NotImplementedError('Abstract method')
 
   @abc.abstractmethod
   def get_hyperparameters(self):
@@ -123,7 +118,7 @@ def get_hyperparameters(self):
     Returns:
       The set of hyperparameters for blackbox function runs.
     """
-    raise NotImplementedError("Abstract method")
+    raise NotImplementedError('Abstract method')
 
   @abc.abstractmethod
   def get_state(self):
@@ -136,7 +131,7 @@ def get_state(self):
     Returns:
       The state of the optimizer.
     """
-    raise NotImplementedError("Abstract method")
+    raise NotImplementedError('Abstract method')
 
   @abc.abstractmethod
   def update_state(self, evaluation_stats):
@@ -149,7 +144,7 @@ def update_state(self, evaluation_stats):
 
     Returns:
     """
-    raise NotImplementedError("Abstract method")
+    raise NotImplementedError('Abstract method')
 
   @abc.abstractmethod
   def set_state(self, state):
@@ -162,7 +157,7 @@ def set_state(self, state):
 
     Returns:
     """
-    raise NotImplementedError("Abstract method")
+    raise NotImplementedError('Abstract method')
 
 
 class MCBlackboxOptimizer(BlackboxOptimizer):
@@ -178,9 +173,9 @@ def __init__(self,
                num_top_directions=0,
                ga_optimizer=None):
     # Check step_size and ga_optimizer
-    if bool(step_size) == bool(ga_optimizer):
+    if (step_size is None) == (ga_optimizer is None):
       raise ValueError(
-          "Exactly one of step_size and ga_optimizer should be provided")
+          'Exactly one of step_size and ga_optimizer should be provided')
     if step_size:
       ga_optimizer = gradient_ascent_optimization_algorithms.MomentumOptimizer(
           step_size=step_size, momentum=0.0)
@@ -190,7 +185,7 @@ def __init__(self,
     self.normalize_fvalues = normalize_fvalues
     self.hyperparameters_update_method = hyperparameters_update_method
     self.num_top_directions = num_top_directions
-    if hyperparameters_update_method == "state_normalization":
+    if hyperparameters_update_method == 'state_normalization':
       self.state_dim = extra_params[0]
       self.nb_steps = 0
       self.sum_state_vector = [0.0] * self.state_dim
@@ -217,9 +212,9 @@ def run_step(self, perturbations, function_values, current_input,
     gradient = np.zeros(dim)
     for i, perturbation in enumerate(top_ps):
       function_value = top_fs[i]
-      if self.est_type == "forward_fd":
+      if self.est_type == 'forward_fd':
         gradient_sample = (function_value - current_value) * perturbation
-      elif self.est_type == "antithetic":
+      elif self.est_type == 'antithetic':
         gradient_sample = function_value * perturbation
       gradient_sample /= self.precision_parameter**2
       gradient += gradient_sample
@@ -228,21 +223,21 @@ def run_step(self, perturbations, function_values, current_input,
     # in that code, the denominator for antithetic was num_top_directions.
     # we maintain compatibility for now so that the same hyperparameters
     # currently used in Toaster will have the same effect
-    if self.est_type == "antithetic" and len(top_ps) < len(perturbations):
+    if self.est_type == 'antithetic' and len(top_ps) < len(perturbations):
       gradient *= 2
     # Use the gradient ascent optimizer to compute the next parameters with the
    # gradients
     return self.ga_optimizer.run_step(current_input, gradient)
 
   def get_hyperparameters(self):
-    if self.hyperparameters_update_method == "state_normalization":
+    if self.hyperparameters_update_method == 'state_normalization':
       return self.mean_state_vector + self.std_state_vector
     else:
       return []
 
   def get_state(self):
     ga_state = self.ga_optimizer.get_state()
-    if self.hyperparameters_update_method == "state_normalization":
+    if self.hyperparameters_update_method == 'state_normalization':
       current_state = [self.nb_steps]
       current_state += self.sum_state_vector
       current_state += self.squares_state_vector
@@ -252,7 +247,7 @@ def get_state(self):
     return ga_state
 
   def update_state(self, evaluation_stats):
-    if self.hyperparameters_update_method == "state_normalization":
+    if self.hyperparameters_update_method == 'state_normalization':
       self.nb_steps += evaluation_stats[0]
       evaluation_stats = evaluation_stats[1:]
       first_half = evaluation_stats[:self.state_dim]
@@ -276,7 +271,7 @@ def update_state(self, evaluation_stats):
       ]
 
   def set_state(self, state):
-    if self.hyperparameters_update_method == "state_normalization":
+    if self.hyperparameters_update_method == 'state_normalization':
       self.nb_steps = state[0]
       state = state[1:]
       self.sum_state_vector = state[:self.state_dim]
@@ -290,17 +285,14 @@ def set_state(self, state):
     self.ga_optimizer.set_state(state)
 
 
-"""
-secondorder optimizers
-"""
-r"""Experimental optimizers based on blackbox ES.
+# pylint: disable=pointless-string-statement
+"""Secondorder optimizers.
+
+Experimental optimizers based on blackbox ES.
 
 See class descriptions for more detailed notes on each algorithm.
 """
 
-from absl import flags
-import scipy.optimize as sp_opt
-
 _GRAD_TYPE = flags.DEFINE_string('grad_type', 'MC', 'Gradient estimator.')
 _TR_INIT_RADIUS = flags.DEFINE_float('tr_init_radius', 1,
                                      'Initial radius for TR method.')
@@ -328,8 +320,6 @@ def set_state(self, state):
                                    'Minimum radius of trust region.')
 
 DEFAULT_ARMIJO = 1e-4
-
-# pylint: disable=pointless-string-statement
 """Gradient estimators.
 The blackbox pipeline has two steps:
 estimate gradient/Hessian --> optimizer --> next weight
@@ -461,6 +451,8 @@ class QuadraticModel(object):
   f(x) = 1/2x^TAx + b^Tx + c
   """
 
+  # pylint: disable=invalid-name
+  # argument Av should be capitalized as such for mathematical convention
   def __init__(self, Av, b, c=0):
     """Initialize quadratic function.
 
@@ -474,6 +466,7 @@ def __init__(self, Av, b, c=0):
     self.b = b
     self.c = c
 
+  # pylint: enable=invalid-name
   def f(self, x):
     """Evaluate the quadratic function.
 
@@ -793,8 +786,7 @@ def trust_region_test(self, current_input, current_value):
       absolute_max_reward = abs(current_value)
     else:
       absolute_max_reward = abs(self.accepted_function_value)
-    if absolute_max_reward < 1e-8:
-      absolute_max_reward = 1e-8
+    absolute_max_reward = max(absolute_max_reward, 1e-08)
     abs_ratio = (
         abs(current_value - self.accepted_function_value) /
         absolute_max_reward)
@@ -1028,8 +1020,8 @@ def run_step(self, perturbations, function_values, current_input,
       # This point was just returned to after rejecting a step.
       # We update the model by averaging the previous gradient/Hessian
      # with the current perturbations. Then we set is_returned_step to False
-      # in preparation for taking the next step after re-solving the trust region
-      # at this point again, with smaller radius. """
+      # in preparation for taking the next step after re-solving the
+      # trust region at this point again, with smaller radius. """
      self.accepted_quadratic_model = mf
      self.accepted_weights = current_input
     # This step has been accepted, so store the most recent quadratic model

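Note on the one semantic change in this file: the constructor guard moves from a truthiness test (bool(step_size) == bool(ga_optimizer)) to an explicit presence check ((step_size is None) == (ga_optimizer is None)). A minimal standalone sketch of why that matters (illustrative only, not the repository's class; the helper names here are hypothetical):

def exactly_one_provided_old(step_size=None, ga_optimizer=None):
  # Old form: a provided-but-falsy value such as step_size=0.0 looks "missing".
  return bool(step_size) != bool(ga_optimizer)

def exactly_one_provided_new(step_size=None, ga_optimizer=None):
  # New form: only None counts as "not provided".
  return (step_size is None) != (ga_optimizer is None)

# step_size=0.0 is provided but falsy: the old check wrongly reports "not exactly one".
assert not exactly_one_provided_old(step_size=0.0)
assert exactly_one_provided_new(step_size=0.0)
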
compiler_opt/es/blackbox_optimizers_test.py

Lines changed: 14 additions & 10 deletions
@@ -21,14 +21,15 @@
 # https://arxiv.org/abs/1804.02395
 #
 ###############################################################################
-"""Tests for google3.learning.brain.contrib.blackbox.blackbox_optimization_algorithms."""
+r"""Tests for blackbox_optimization_algorithms."""
 
 import numpy as np
 
-import blackbox_optimizers as bo
 from absl.testing import absltest
 from absl.testing import parameterized
-import gradient_ascent_optimization_algorithms
+
+from compiler_opt.es import blackbox_optimizers as bo
+from compiler_opt.es import gradient_ascent_optimization_algorithms
 
 perturbation_array = np.array([[0, 1], [2, -1], [4, 2],
                                [-2, -2], [0, 3], [0, -3], [0, 4], [0, -4],
@@ -109,9 +110,6 @@ def test_mc_gradient_with_ga_optimizer(self, perturbations, function_values,
     np.testing.assert_array_almost_equal(expected_gradient, gradient)
 
 
-"""Tests for google3.learning.brain.contrib.blackbox.secondorder_blackbox_optimizers."""
-
-
 class SecondorderBlackboxOptimizersTest(absltest.TestCase):
 
   class GenericFunction(object):
@@ -126,7 +124,7 @@ def setUp(self):
     The matrix A is indefinite and has eigs
     [2.15, 0.53, -2.67]
     """
-    super(SecondorderBlackboxOptimizersTest, self).setUp()
+    super().setUp()
     # pylint: disable=bad-whitespace,invalid-name
     self.A = np.array([[1, -1, 0], [-1, 0, 2], [0, 2, -1]])
     self.b = np.array([1, 0, 1])
@@ -163,9 +161,15 @@ def testProjectedGradientOptimizer(self):
     over the nonnegative orthant.
     The exact solution is (0,1).
     """
-    cost_function = lambda x: (x[0] + 1)**2 + (x[1] - 1)**2
-    cost_gradient = lambda x: np.array([2 * (x[0] + 1), 2 * (x[1] - 1)])
-    projector = lambda x: np.maximum(0, x)
+
+    def cost_function(x):
+      return (x[0] + 1)**2 + (x[1] - 1)**2
+
+    def cost_gradient(x):
+      return np.array([2 * (x[0] + 1), 2 * (x[1] - 1)])
+
+    def projector(x):
+      return np.maximum(0, x)
 
     objective_function = SecondorderBlackboxOptimizersTest.GenericFunction()
     objective_function.f = cost_function

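The lambdas in testProjectedGradientOptimizer become named functions only to satisfy pylint; the test's behavior is unchanged. As a standalone illustration of what that test exercises (plain NumPy here, not the module's optimizer class), projected gradient descent on (x0 + 1)^2 + (x1 - 1)^2 over the nonnegative orthant converges to the exact solution (0, 1) stated in the docstring:

import numpy as np

def cost_gradient(x):
  return np.array([2 * (x[0] + 1), 2 * (x[1] - 1)])

def projector(x):
  return np.maximum(0, x)

x = np.array([5.0, -3.0])  # arbitrary starting point
for _ in range(200):
  # Take a plain gradient step on the cost, then project back onto x >= 0.
  x = projector(x - 0.1 * cost_gradient(x))
print(x)  # approximately [0., 1.]
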
compiler_opt/es/gradient_ascent_optimization_algorithms.py

Lines changed: 4 additions & 4 deletions
@@ -31,7 +31,6 @@
 import numpy as np
 
 
-# TODO(kchoro): Borrow JAXs optimizer library here. Integrated into Blackbox-v2.
 class GAOptimizer(metaclass=abc.ABCMeta):
   """Abstract class for general gradient ascent optimizers.
 
@@ -124,9 +123,10 @@ def set_state(self, state):
 
 class AdamOptimizer(GAOptimizer):
   """Class implementing ADAM gradient ascent optimizer.
-
-  The state is the first moment moving average, the second moment moving average,
-  and t (current step number) combined in that order into one list
+
+  The state is the first moment moving average, the second
+  moment moving average, and t (current step number)
+  combined in that order into one list
   """
 
   def __init__(self, step_size, beta1=0.9, beta2=0.999, epsilon=1e-07):

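The reflowed AdamOptimizer docstring above describes the optimizer state as the first moment moving average, the second moment moving average, and the step count t. A minimal sketch of one Adam-style gradient ascent step under those assumptions (standalone NumPy; the class's actual methods are not shown in this diff, and the step_size default below is only an example):

import numpy as np

def adam_ascent_step(params, gradient, m, v, t,
                     step_size=0.01, beta1=0.9, beta2=0.999, epsilon=1e-07):
  # Update the state the docstring describes: first moment m, second moment v, step t.
  t += 1
  m = beta1 * m + (1 - beta1) * gradient
  v = beta2 * v + (1 - beta2) * gradient**2
  m_hat = m / (1 - beta1**t)  # bias correction
  v_hat = v / (1 - beta2**t)
  # Ascent: move along the gradient rather than against it.
  new_params = params + step_size * m_hat / (np.sqrt(v_hat) + epsilon)
  return new_params, m, v, t
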
compiler_opt/es/gradient_ascent_optimization_algorithms_test.py

Lines changed: 3 additions & 8 deletions
@@ -21,19 +21,15 @@
 # https://arxiv.org/abs/1804.02395
 #
 ###############################################################################
-"""Tests for google3.learning.brain.contrib.blackbox.gradient_ascent_optimization_algorithms."""
+r"""Tests for gradient_ascent_optimization_algorithms."""
 
 import numpy as np
 
-# from google3.learning.brain.contrib.blackbox import gradient_ascent_optimization_algorithms
-# from google3.testing.pybase import googletest
-
-# from google3.testing.pybase import parameterized
-import gradient_ascent_optimization_algorithms
 from absl.testing import absltest
-
 from absl.testing import parameterized
 
+from compiler_opt.es import gradient_ascent_optimization_algorithms
+
 
 class GradientAscentOptimizationAlgorithmsTest(parameterized.TestCase):
 
@@ -85,5 +81,4 @@ def test_adam_step(self, step_size, beta1, beta2, ini_parameter, gradient1,
 
 
 if __name__ == '__main__':
-  # googletest.main()
   absltest.main()

0 commit comments
