Skip to content

Commit 9b1232f

Browse files
authored
Make the worker manager type a parameter to train_eval (#151)
This allows reusing `train_locally` with other worker managers. It's the minimum refactoring necessary - subsequent ones would make this a library and also remove the `_local` suffix from this and a few other places, since they aren't "local" in any sense anymore.
1 parent a02aa36 commit 9b1232f

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

compiler_opt/rl/train_locally.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959

6060

6161
@gin.configurable
62-
def train_eval(agent_name=constant.AgentName.PPO,
62+
def train_eval(worker_manager_class=LocalWorkerPool,
63+
agent_name=constant.AgentName.PPO,
6364
warmstart_policy_dir=None,
6465
num_policy_iterations=0,
6566
num_modules=100,
@@ -133,7 +134,7 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
133134
logging.info('Loaded Reward Stat Map from disk, containing %d modules',
134135
len(reward_stat_map))
135136

136-
with LocalWorkerPool(
137+
with worker_manager_class(
137138
worker_class=problem_config.get_runner_type(),
138139
count=FLAGS.num_workers,
139140
moving_average_decay_rate=moving_average_decay_rate) as worker_pool:

0 commit comments

Comments
 (0)