Commit 3a9bb59

Move some gradient ascent optimizer flags into the gin config
This patch moves the es_trainer_lib command-line flags that select the gradient ascent optimizer and set its hyperparameters into the gin config, alongside all the other model/optimizer hyperparameters. Given their nature, these settings don't make sense as command-line parameters, and everything similar is already defined in gin configs.

Reviewers: mtrofin
Reviewed By: mtrofin
Pull Request: #433
1 parent 4b653c8 commit 3a9bb59


compiler_opt/es/es_trainer_lib.py

Lines changed: 18 additions & 17 deletions
@@ -14,6 +14,7 @@
 """Local ES trainer."""
 
 from absl import flags, logging
+import enum
 import functools
 import gin
 import tensorflow as tf
@@ -31,22 +32,12 @@
 
 FLAGS = flags.FLAGS
 
-_BETA1 = flags.DEFINE_float("beta1", 0.9,
-                            "Beta1 for ADAM gradient ascent optimizer.")
-_BETA2 = flags.DEFINE_float("beta2", 0.999,
-                            "Beta2 for ADAM gradient ascent optimizer.")
 _GRAD_REG_ALPHA = flags.DEFINE_float(
     "grad_reg_alpha", 0.01,
     "Weight of regularization term in regression gradient.")
 _GRAD_REG_TYPE = flags.DEFINE_string(
     "grad_reg_type", "ridge",
     "Regularization method to use with regression gradient.")
-_GRADIENT_ASCENT_OPTIMIZER_TYPE = flags.DEFINE_string(
-    "gradient_ascent_optimizer_type", None,
-    "Gradient ascent optimization algorithm: 'momentum' or 'adam'")
-flags.mark_flag_as_required("gradient_ascent_optimizer_type")
-_MOMENTUM = flags.DEFINE_float(
-    "momentum", 0.0, "Momentum for momentum gradient ascent optimizer.")
 _OUTPUT_PATH = flags.DEFINE_string("output_path", "",
                                    "Path to write all output")
 _PRETRAINED_POLICY_PATH = flags.DEFINE_string(
@@ -60,11 +51,22 @@
     "List of paths to training corpora")
 
 
+@gin.constants_from_enum(module="es_trainer_lib")
+class GradientAscentOptimizerType(enum.Enum):
+  INVALID = 0
+  MOMENTUM = enum.auto()
+  ADAM = enum.auto()
+
+
 @gin.configurable
 def train(additional_compilation_flags=(),
           delete_compilation_flags=(),
           replace_compilation_flags=(),
-          worker_class=None):
+          worker_class=None,
+          beta1=0.9,
+          beta2=0.999,
+          momentum=0.0,
+          gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM):
   """Train with ES."""
 
   if not _TRAIN_CORPORA.value:
@@ -130,21 +132,20 @@ def train(additional_compilation_flags=(),
   # TODO(linzinan): delete all unused parameters.
 
   # ------------------ GRADIENT ASCENT OPTIMIZERS ------------------------------
-  if _GRADIENT_ASCENT_OPTIMIZER_TYPE.value == "momentum":
+  if gradient_ascent_optimizer_type == GradientAscentOptimizerType.MOMENTUM:
     logging.info("Running momentum gradient ascent optimizer")
     # You can obtain a vanilla gradient ascent optimizer by setting momentum=0.0
     # and setting step_size to the desired learning rate.
     gradient_ascent_optimizer = (
         gradient_ascent_optimization_algorithms.MomentumOptimizer(
-            learner_config.step_size, _MOMENTUM.value))
-  elif _GRADIENT_ASCENT_OPTIMIZER_TYPE.value == "adam":
+            learner_config.step_size, momentum))
+  elif gradient_ascent_optimizer_type == GradientAscentOptimizerType.ADAM:
     logging.info("Running Adam gradient ascent optimizer")
     gradient_ascent_optimizer = (
         gradient_ascent_optimization_algorithms.AdamOptimizer(
-            learner_config.step_size, _BETA1.value, _BETA2.value))
+            learner_config.step_size, beta1, beta2))
   else:
-    logging.info("No gradient ascent \
-                 optimizer selected. Stopping.")
+    logging.info("No gradient ascent optimizer selected. Stopping.")
     return
   # ----------------------------------------------------------------------------
 
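For illustration, here is a minimal, self-contained sketch (not part of the patch) of how the new gin surface can be exercised. The train() stub below only mirrors the optimizer parameters from the diff; the real train() in es_trainer_lib does much more. The binding string shows the gin syntax that replaces the removed flags, and the constant prefix es_trainer_lib comes from the module= argument to gin.constants_from_enum.

import enum

import gin


@gin.constants_from_enum(module="es_trainer_lib")
class GradientAscentOptimizerType(enum.Enum):
  INVALID = 0
  MOMENTUM = enum.auto()
  ADAM = enum.auto()


# Stub carrying the same gin-configurable optimizer parameters as the patched
# train(); it only echoes the values it receives.
@gin.configurable
def train(beta1=0.9,
          beta2=0.999,
          momentum=0.0,
          gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM):
  return beta1, beta2, momentum, gradient_ascent_optimizer_type


# These bindings take over from the old --gradient_ascent_optimizer_type,
# --momentum, --beta1 and --beta2 flags; in practice they would live in the
# experiment's .gin file next to the other hyperparameters.
gin.parse_config("""
train.gradient_ascent_optimizer_type = %es_trainer_lib.GradientAscentOptimizerType.MOMENTUM
train.momentum = 0.9
""")

print(train())  # (0.9, 0.999, 0.9, <GradientAscentOptimizerType.MOMENTUM: 1>)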