Skip to content

Commit f943179

Browse files
Add base class for WorkerPoolManagers (#470)
This patch introduces a base class that worker pool managers can inherit from. It enforces a common interface, particularly for instantiation, which has caused some trouble during recent refactorings. I've validated that this patch would have caught the issues that have already been fixed.
1 parent 8f1f1a9 commit f943179

File tree

8 files changed

+64
-9
lines changed

8 files changed

+64
-9
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
from absl import flags, logging
3939
# pylint: disable=unused-import
4040
from compiler_opt.distributed import worker
41+
from compiler_opt.distributed import worker_manager
4142

42-
from contextlib import AbstractContextManager
4343
from multiprocessing import connection
4444
from typing import Any
4545
from collections.abc import Callable
@@ -281,7 +281,7 @@ def close_local_worker_pool(pool: worker.FixedWorkerPool):
281281
stub.join()
282282

283283

284-
class LocalWorkerPoolManager(AbstractContextManager):
284+
class LocalWorkerPoolManager(worker_manager.WorkerManager):
285285
"""A pool of workers hosted on the local machines, each in its own process."""
286286

287287
def __init__(self,
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""The interface for WorkerManager."""
15+
16+
import abc
17+
from contextlib import AbstractContextManager
18+
import pickle
19+
20+
from compiler_opt.distributed import worker
21+
22+
23+
class WorkerManager(AbstractContextManager, metaclass=abc.ABCMeta):
  """Abstract interface that worker pool managers should derive from.

  A worker manager is a context manager: entering it yields a
  `worker.FixedWorkerPool`, and exiting it is where an implementation is
  expected to release the pool.

  `__init__`, `__enter__` and `__exit__` are all abstract, so a subclass
  cannot be instantiated until it overrides each of them. Declaring
  `__init__` abstract documents the constructor signature that callers
  rely on when they are handed a manager *class* rather than an instance.
  """

  @abc.abstractmethod
  def __init__(self,
               worker_class: type[worker.Worker],
               pickle_func=pickle.dumps,
               *,
               count: int | None,
               worker_args: tuple = (),
               worker_kwargs: dict | None = None):
    """Declares the constructor signature implementations must accept.

    Args:
      worker_class: the `worker.Worker` subclass the pool is built from.
      pickle_func: serialization callable; defaults to `pickle.dumps`.
      count: keyword-only and required. Presumably the number of workers,
        with None leaving the choice to the implementation -- confirm
        against the concrete managers.
      worker_args: positional arguments intended for `worker_class`.
      worker_kwargs: keyword arguments intended for `worker_class`; None
        is used instead of a mutable `{}` default.

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    # NotImplementedError (not ValueError) is the conventional signal that
    # an abstract method was reached without being overridden.
    raise NotImplementedError("Not Implemented")

  @abc.abstractmethod
  def __enter__(self) -> worker.FixedWorkerPool:
    """Enters the context and returns the worker pool.

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    raise NotImplementedError("Not Implemented")

  @abc.abstractmethod
  def __exit__(self, *args):
    """Exits the context.

    Args:
      *args: the exception type, value and traceback supplied by the
        `with` statement (all None when the body completed normally).

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    raise NotImplementedError("Not Implemented")

compiler_opt/es/es_trainer_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# here as these errors are false positives.
2525
# pytype: disable=pyi-error
2626

27+
from compiler_opt.distributed import worker_manager
2728
from compiler_opt.distributed.local import local_worker_manager
2829
from compiler_opt.es import blackbox_optimizers
2930
from compiler_opt.es import gradient_ascent_optimization_algorithms
@@ -68,7 +69,9 @@ def train(additional_compilation_flags=(),
6869
beta2=0.999,
6970
momentum=0.0,
7071
gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM,
71-
worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
72+
worker_manager_class: type[
73+
worker_manager.WorkerManager] = local_worker_manager
74+
.LocalWorkerPoolManager):
7275
"""Train with ES."""
7376

7477
if not _TRAIN_CORPORA.value:

compiler_opt/rl/distributed/ppo_collect_lib.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tf_agents.utils import common
3434
from tf_agents.trajectories import trajectory
3535

36+
from compiler_opt.distributed import worker_manager
3637
from compiler_opt.rl import gin_external_configurables # pylint: disable=unused-import
3738
from compiler_opt.rl import local_data_collector
3839
from compiler_opt.rl import corpus
@@ -102,7 +103,8 @@ def observe(self, result: compilation_runner.CompilationResult) -> None:
102103

103104
def collect(corpus_path: str, replay_buffer_server_address: str,
104105
variable_container_server_address: str, num_workers: int | None,
105-
worker_manager_class, sequence_length: int) -> None:
106+
worker_manager_class: type[worker_manager.WorkerManager],
107+
sequence_length: int) -> None:
106108
"""Collects experience using a policy updated after every episode.
107109
108110
Args:

compiler_opt/rl/distributed/ppo_eval_lib.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tf_agents.train.utils import train_utils
2828
from tf_agents.utils import common
2929

30+
from compiler_opt.distributed import worker_manager
3031
from compiler_opt.rl import data_reader
3132
from compiler_opt.rl import local_data_collector
3233
from compiler_opt.rl import gin_external_configurables # pylint: disable=unused-import
@@ -39,7 +40,7 @@
3940

4041
def evaluate(root_dir: str, corpus_path: str,
4142
variable_container_server_address: str, num_workers: int | None,
42-
worker_manager_class):
43+
worker_manager_class: type[worker_manager.WorkerManager]):
4344
"""Evaluate a given policy on the given corpus.
4445
4546
Args:

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from compiler_opt.rl import env
4444

4545
from compiler_opt.distributed import worker
46+
from compiler_opt.distributed import worker_manager
4647
from compiler_opt.distributed import buffered_scheduler
4748
from compiler_opt.distributed.local import local_worker_manager
4849

@@ -900,7 +901,9 @@ def gen_trajectories(
900901
profiling_file_path: str | None = None,
901902
worker_wait_sec: float | None = None,
902903
worker_class_type=ModuleWorker,
903-
worker_manager_class=local_worker_manager.LocalWorkerPoolManager,
904+
worker_manager_class: type[
905+
worker_manager.WorkerManager] = local_worker_manager
906+
.LocalWorkerPoolManager,
904907
):
905908
"""Generates all trajectories for imitation learning training.
906909

compiler_opt/rl/train_locally.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tf_agents.agents import tf_agent
2828
from tf_agents.system import system_multiprocessing as multiprocessing
2929

30+
from compiler_opt.distributed import worker_manager
3031
from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
3132
from compiler_opt.rl import agent_config
3233
from compiler_opt.rl import best_trajectory
@@ -58,7 +59,8 @@
5859

5960

6061
@gin.configurable
61-
def train_eval(worker_manager_class=LocalWorkerPoolManager,
62+
def train_eval(worker_manager_class: type[
63+
worker_manager.WorkerManager] = LocalWorkerPoolManager,
6264
agent_config_type=agent_config.PPOAgentConfig,
6365
warmstart_policy_dir=None,
6466
num_policy_iterations=0,

compiler_opt/tools/generate_default_trace.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import tensorflow as tf
2727

2828
from compiler_opt.distributed import worker
29+
from compiler_opt.distributed import worker_manager
2930
from compiler_opt.distributed import buffered_scheduler
3031
from compiler_opt.distributed.local import local_worker_manager
3132

@@ -112,8 +113,9 @@ def main(_):
112113
generate_trace()
113114

114115

115-
def generate_trace(
116-
worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
116+
def generate_trace(worker_manager_class: type[
117+
worker_manager.WorkerManager] = local_worker_manager.LocalWorkerPoolManager
118+
):
117119

118120
config = registry.get_configuration()
119121

0 commit comments

Comments
 (0)