
Commit e4446c1

Add num_workers flag to es_trainer (#473)
This patch adds a num_workers flag to es_trainer, matching the behavior of the train_locally script, which matters for some internal distributed-training scripts. Deriving the worker count from the number of perturbations ignores the underlying hardware entirely, and the hardware is what should actually determine that count. The existing logic also failed to account for antithetic sampling doubling the number of models evaluated per iteration.
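
For intuition, here is a minimal sketch of the sizing mismatch described above, using hypothetical numbers (none of the values below come from the patch):

# With antithetic sampling, each perturbation epsilon is evaluated in both
# directions, theta + sigma*epsilon and theta - sigma*epsilon, so the
# per-iteration workload is twice the perturbation count.
total_num_perturbations = 100                        # illustrative value
old_worker_count = total_num_perturbations           # old heuristic: 100
models_per_iteration = 2 * total_num_perturbations   # actual load: 200

# The old count tracks only half the real workload, and neither number
# says anything about how many workers the available hardware can sustain.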
1 parent f86feef commit e4446c1


compiler_opt/es/es_trainer_lib.py

Lines changed: 3 additions & 1 deletion
@@ -51,6 +51,8 @@
                                     to the data collection requests.")
 _TRAIN_CORPORA = flags.DEFINE_string("train_corpora", "",
                                      "List of paths to training corpora")
+_NUM_WORKERS = flags.DEFINE_integer("num_workers", 100,
+                                    "The number of workers to create.")


 @gin.constants_from_enum(module="es_trainer_lib")
@@ -216,7 +218,7 @@ def train(additional_compilation_flags=(),

   with worker_manager_class(
       worker_class,
-      count=learner_config.total_num_perturbations,
+      count=_NUM_WORKERS.value,
       worker_kwargs=dict(gin_config=gin.operative_config_str())) as pool:
     learner.set_baseline(pool)
     for _ in range(learner_config.total_steps):
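
As a quick sanity check, here is a self-contained sketch (assuming only absl, which the file already uses; the argv list is illustrative) showing how the new flag parses and overrides its default of 100:

from absl import flags

FLAGS = flags.FLAGS

# Same definition the patch adds to es_trainer_lib.py.
_NUM_WORKERS = flags.DEFINE_integer("num_workers", 100,
                                    "The number of workers to create.")

# Parse an argv-style list; the first element is the program name.
FLAGS(["es_trainer", "--num_workers=32"])
assert _NUM_WORKERS.value == 32  # worker count is now set by the user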
