Skip to content

Commit a12a099

Browse files
authored
[Bug Fix] Wrapping args of LocalWorkerPoolManager with worker_kwargs (#466)
Co-authored-by: svkeerthy <[email protected]>
1 parent c9ebbf2 commit a12a099

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

compiler_opt/rl/distributed/ppo_collect_lib.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ def sequence_example_iterator_fn(seq_ex: list[str]):
167167
with worker_manager_class(
168168
worker_class=problem_config.get_runner_type(),
169169
count=num_workers,
170-
moving_average_decay_rate=1,
171-
create_observer_fns=create_observer_fns) as worker_pool:
170+
worker_kwargs=dict(
171+
moving_average_decay_rate=1,
172+
create_observer_fns=create_observer_fns)) as worker_pool:
172173

173174
data_collector = local_data_collector.LocalDataCollector(
174175
cps=cps,

compiler_opt/rl/distributed/ppo_eval_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def sequence_example_iterator_fn(seq_ex: list[str]):
105105
with worker_manager_class(
106106
worker_class=problem_config.get_runner_type(),
107107
count=num_workers,
108-
moving_average_decay_rate=1) as worker_pool:
108+
worker_kwargs=dict(moving_average_decay_rate=1)) as worker_pool:
109109
logging.info('constructed pool')
110110
collector = local_data_collector.LocalDataCollector(
111111
cps=cps,

compiler_opt/rl/train_locally.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def sequence_example_iterator_fn(seq_ex: list[str]):
146146
with worker_manager_class(
147147
worker_class=problem_config.get_runner_type(),
148148
count=FLAGS.num_workers,
149-
moving_average_decay_rate=moving_average_decay_rate) as worker_pool:
149+
worker_kwargs=dict(
150+
moving_average_decay_rate=moving_average_decay_rate)) as worker_pool:
150151
data_collector = local_data_collector.LocalDataCollector(
151152
cps=cps,
152153
num_modules=num_modules,

0 commit comments

Comments
 (0)