
Commit 0db6cd1

Allow configuring the worker manager class in es_trainer_lib

This patch makes the worker manager class a parameter of es_trainer_lib.train so that it can be overridden within a gin config. This enables using alternative implementations, such as Google-internal ones that take advantage of internally available distributed resources.

Reviewers: mtrofin
Reviewed By: mtrofin
Pull Request: #430
1 parent 3a9bb59 commit 0db6cd1

1 file changed (+4 −3 lines)

compiler_opt/es/es_trainer_lib.py

@@ -66,7 +66,8 @@ def train(additional_compilation_flags=(),
           beta1=0.9,
           beta2=0.999,
           momentum=0.0,
-          gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM):
+          gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM,
+          worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
   """Train with ES."""

   if not _TRAIN_CORPORA.value:
@@ -215,8 +216,8 @@ def train(additional_compilation_flags=(),
   logging.info("Ready to train: running for %d steps.",
               learner_config.total_steps)

-  with local_worker_manager.LocalWorkerPoolManager(
-      worker_class, learner_config.total_num_perturbations) as pool:
+  with worker_manager_class(worker_class,
+                            learner_config.total_num_perturbations) as pool:
     for _ in range(learner_config.total_steps):
       learner.run_step(pool)
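For illustration, here is a minimal sketch of what a drop-in alternative manager could look like, inferred only from how the call site above uses the parameter: the object is constructed as worker_manager_class(worker_class, count) and entered as a context manager that yields a worker pool. The class name MyWorkerPoolManager, the backend comment, and the no-argument worker construction are hypothetical, not part of this repository.

import gin


@gin.configurable
class MyWorkerPoolManager:
  """Hypothetical alternative to LocalWorkerPoolManager (sketch only)."""

  def __init__(self, worker_class, count):
    self._worker_class = worker_class
    self._count = count
    self._pool = None

  def __enter__(self):
    # Spin up `count` workers on whatever backend this manager targets,
    # e.g. an internal distributed scheduler instead of local processes.
    self._pool = [self._worker_class() for _ in range(self._count)]
    return self._pool

  def __exit__(self, exc_type, exc_value, traceback):
    # Tear down the workers and release any backend resources.
    self._pool = None
    return False

With such a class registered via @gin.configurable, the override the commit message describes could then be expressed in a gin config with a binding along the lines of train.worker_manager_class = @MyWorkerPoolManager.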
