 14 |  14 | """Local ES trainer."""
 15 |  15 |
 16 |  16 | from absl import flags, logging
    |  17 | +import enum
 17 |  18 | import functools
 18 |  19 | import gin
 19 |  20 | import tensorflow as tf

 31 |  32 |
 32 |  33 | FLAGS = flags.FLAGS
 33 |  34 |
 34 |     | -_BETA1 = flags.DEFINE_float("beta1", 0.9,
 35 |     | -                            "Beta1 for ADAM gradient ascent optimizer.")
 36 |     | -_BETA2 = flags.DEFINE_float("beta2", 0.999,
 37 |     | -                            "Beta2 for ADAM gradient ascent optimizer.")
 38 |  35 | _GRAD_REG_ALPHA = flags.DEFINE_float(
 39 |  36 |     "grad_reg_alpha", 0.01,
 40 |  37 |     "Weight of regularization term in regression gradient.")
 41 |  38 | _GRAD_REG_TYPE = flags.DEFINE_string(
 42 |  39 |     "grad_reg_type", "ridge",
 43 |  40 |     "Regularization method to use with regression gradient.")
 44 |     | -_GRADIENT_ASCENT_OPTIMIZER_TYPE = flags.DEFINE_string(
 45 |     | -    "gradient_ascent_optimizer_type", None,
 46 |     | -    "Gradient ascent optimization algorithm: 'momentum' or 'adam'")
 47 |     | -flags.mark_flag_as_required("gradient_ascent_optimizer_type")
 48 |     | -_MOMENTUM = flags.DEFINE_float(
 49 |     | -    "momentum", 0.0, "Momentum for momentum gradient ascent optimizer.")
 50 |  41 | _OUTPUT_PATH = flags.DEFINE_string("output_path", "",
 51 |  42 |                                    "Path to write all output")
 52 |  43 | _PRETRAINED_POLICY_PATH = flags.DEFINE_string(

 60 |  51 |     "List of paths to training corpora")
 61 |  52 |
 62 |  53 |
    |  54 | +@gin.constants_from_enum(module="es_trainer_lib")
    |  55 | +class GradientAscentOptimizerType(enum.Enum):
    |  56 | +  INVALID = 0
    |  57 | +  MOMENTUM = enum.auto()
    |  58 | +  ADAM = enum.auto()
    |  59 | +
    |  60 | +
 63 |  61 | @gin.configurable
 64 |  62 | def train(additional_compilation_flags=(),
 65 |  63 |           delete_compilation_flags=(),
 66 |  64 |           replace_compilation_flags=(),
 67 |     | -          worker_class=None):
    |  65 | +          worker_class=None,
    |  66 | +          beta1=0.9,
    |  67 | +          beta2=0.999,
    |  68 | +          momentum=0.0,
    |  69 | +          gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM):
 68 |  70 |   """Train with ES."""
 69 |  71 |
 70 |  72 |   if not _TRAIN_CORPORA.value:

@@ -130,21 +132,20 @@ def train(additional_compilation_flags=(),

130 | 132 |   # TODO(linzinan): delete all unused parameters.
131 | 133 |
132 | 134 |   # ------------------ GRADIENT ASCENT OPTIMIZERS ------------------------------
133 |     | -  if _GRADIENT_ASCENT_OPTIMIZER_TYPE.value == "momentum":
    | 135 | +  if gradient_ascent_optimizer_type == GradientAscentOptimizerType.MOMENTUM:
134 | 136 |     logging.info("Running momentum gradient ascent optimizer")
135 | 137 |     # You can obtain a vanilla gradient ascent optimizer by setting momentum=0.0
136 | 138 |     # and setting step_size to the desired learning rate.
137 | 139 |     gradient_ascent_optimizer = (
138 | 140 |         gradient_ascent_optimization_algorithms.MomentumOptimizer(
139 |     | -            learner_config.step_size, _MOMENTUM.value))
140 |     | -  elif _GRADIENT_ASCENT_OPTIMIZER_TYPE.value == "adam":
    | 141 | +            learner_config.step_size, momentum))
    | 142 | +  elif gradient_ascent_optimizer_type == GradientAscentOptimizerType.ADAM:
141 | 143 |     logging.info("Running Adam gradient ascent optimizer")
142 | 144 |     gradient_ascent_optimizer = (
143 | 145 |         gradient_ascent_optimization_algorithms.AdamOptimizer(
144 |     | -            learner_config.step_size, _BETA1.value, _BETA2.value))
    | 146 | +            learner_config.step_size, beta1, beta2))
145 | 147 |   else:
146 |     | -    logging.info("No gradient ascent \
147 |     | -                 optimizer selected. Stopping.")
    | 148 | +    logging.info("No gradient ascent optimizer selected. Stopping.")
148 | 149 |     return
149 | 150 |   # ----------------------------------------------------------------------------
150 | 151 |
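This change moves the optimizer hyperparameters off absl flags (--beta1, --beta2, --momentum, --gradient_ascent_optimizer_type) and onto gin-configurable parameters of train(), with the optimizer selected through the new GradientAscentOptimizerType enum. Below is a minimal sketch of how a caller might bind the new parameters, assuming the trainer module is importable as es_trainer_lib and using only standard gin-config APIs; the import path and the binding values are illustrative, not taken from this PR.

import gin

# Importing the trainer module lets @gin.configurable register train() and
# @gin.constants_from_enum register the enum members as gin constants.
import es_trainer_lib  # hypothetical import path; adjust to the real package

gin.parse_config("""
# Replaces the removed --gradient_ascent_optimizer_type flag; the %... syntax
# refers to a gin constant created by @gin.constants_from_enum.
train.gradient_ascent_optimizer_type = %es_trainer_lib.GradientAscentOptimizerType.ADAM

# Replaces the removed --beta1 / --beta2 / --momentum flags.
train.beta1 = 0.9
train.beta2 = 0.999
train.momentum = 0.0
""")

# The bound values are injected when train() is called.
es_trainer_lib.train()

One practical effect of the change is that these hyperparameters now live in the same gin config as the rest of the training setup rather than on the command line, and an optimizer is always selected by default (ADAM) instead of the flag being required.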