diff --git a/tests/tools.py b/tests/tools.py
index 3111839a37..209b5eb1c2 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -158,7 +158,7 @@ def metric_list(self, metric_prefix: str) -> List[str]:
 class RayUnittestBase(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        ray.init(ignore_reinit_error=True)
+        ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
 
     @classmethod
     def tearDownClass(cls):
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 0ec438c2db..811a1ba64d 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -27,7 +27,6 @@ def setUp(self):
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
         self.config.algorithm.repeat_times = 3
-        self.config.explorer.rollout_model.use_v1 = False
         self.config.project = "Trainer-unittest"
         self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.monitor.monitor_type = "tensorboard"
@@ -45,6 +44,7 @@ class TestTrainerCountdown(BaseTrainerCase):
     def test_trainer(self):
         """Test the both and bench mode."""
         # test both mode
+        self.config.explorer.rollout_model.use_v1 = False
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("countdown", "test")
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index f0ddea46c9..a3db72ef90 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -3,8 +3,6 @@
 from copy import deepcopy
 from typing import List
 
-import ray
-
 from trinity.buffer.writer.file_writer import JSONWriter
 from trinity.buffer.writer.sql_writer import SQLWriter
 from trinity.common.config import BufferConfig, StorageConfig
@@ -20,7 +18,6 @@ def is_json_file(path: str) -> bool:
     return path.endswith(".json") or path.endswith(".jsonl")
 
 
-@ray.remote
 class QueueActor:
     """An asyncio.Queue based queue actor."""
 
diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py
index b7cf06b2b5..71e9102999 100644
--- a/trinity/buffer/ray_wrapper.py
+++ b/trinity/buffer/ray_wrapper.py
@@ -55,6 +55,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"sql-{storage_config.name}",
+                namespace=ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
@@ -154,6 +155,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"json-{storage_config.name}",
+                namespace=ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index f696c6decb..271c2931e2 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -19,10 +19,15 @@ class QueueReader(BufferReader):
     def __init__(self, storage_config: StorageConfig, config: BufferConfig):
         assert storage_config.storage_type == StorageType.QUEUE
         self.read_batch_size = config.read_batch_size
-        self.queue = QueueActor.options(
-            name=f"queue-{storage_config.name}",
-            get_if_exists=True,
-        ).remote(storage_config, config)
+        self.queue = (
+            ray.remote(QueueActor)
+            .options(
+                name=f"queue-{storage_config.name}",
+                namespace=ray.get_runtime_context().namespace,
+                get_if_exists=True,
+            )
+            .remote(storage_config, config)
+        )
 
     def read(
         self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
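The change above is this patch's recurring pattern: the `@ray.remote` class decorator is dropped and the actor is created with `ray.remote(cls).options(...)`, so each call site can pin an explicit `namespace`. A minimal sketch of why this matters for the queue reader and writer, assuming an illustrative `Counter` class and names that are not part of the patch: with a matching name, a matching namespace, and `get_if_exists=True`, both call sites resolve to the same actor.

import ray


class Counter:
    """Illustrative stand-in for QueueActor."""

    def __init__(self):
        self.n = 0

    def incr(self):
        self.n += 1
        return self.n


ray.init(namespace="demo")


def get_queue():
    # The first caller creates the named actor; every later caller gets a
    # handle to the existing one, because name + namespace match.
    return (
        ray.remote(Counter)
        .options(
            name="queue-demo",
            namespace=ray.get_runtime_context().namespace,
            get_if_exists=True,
        )
        .remote()
    )


reader, writer = get_queue(), get_queue()
assert ray.get(writer.incr.remote()) == 1
assert ray.get(reader.incr.remote()) == 2  # same underlying actor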
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index 5cd24877d9..ec2316a0ec 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -18,10 +18,15 @@ class QueueWriter(BufferWriter):
     def __init__(self, meta: StorageConfig, config: BufferConfig):
         assert meta.storage_type == StorageType.QUEUE
         self.config = config
-        self.queue = QueueActor.options(
-            name=f"queue-{meta.name}",
-            get_if_exists=True,
-        ).remote(meta, config)
+        self.queue = (
+            ray.remote(QueueActor)
+            .options(
+                name=f"queue-{meta.name}",
+                namespace=ray.get_runtime_context().namespace,
+                get_if_exists=True,
+            )
+            .remote(meta, config)
+        )
 
     def write(self, data: List) -> None:
         ray.get(self.queue.put_batch.remote(data))
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index d4f171803e..124475137c 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -20,7 +20,14 @@
 
 def bench(config: Config) -> None:
     """Evaluate model."""
-    explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
+    explorer = (
+        ray.remote(Explorer)
+        .options(
+            name=EXPLORER_NAME,
+            namespace=ray.get_runtime_context().namespace,
+        )
+        .remote(config)
+    )
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.benchmark.remote())
@@ -34,7 +41,14 @@ def bench(config: Config) -> None:
 def explore(config: Config) -> None:
     """Run explorer."""
     try:
-        explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
+        explorer = (
+            ray.remote(Explorer)
+            .options(
+                name=EXPLORER_NAME,
+                namespace=ray.get_runtime_context().namespace,
+            )
+            .remote(config)
+        )
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
         ray.get(explorer.explore.remote())
@@ -47,7 +61,14 @@ def explore(config: Config) -> None:
 def train(config: Config) -> None:
     """Run trainer."""
     try:
-        trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
+        trainer = (
+            ray.remote(Trainer)
+            .options(
+                name=TRAINER_NAME,
+                namespace=ray.get_runtime_context().namespace,
+            )
+            .remote(config)
+        )
         ray.get(trainer.prepare.remote())
         ray.get(trainer.sync_weight.remote())
         ray.get(trainer.train.remote())
@@ -67,8 +88,23 @@ def both(config: Config) -> None:
     the latest step. The specific number of experiences may vary for
     different algorithms and tasks.
     """
-    explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
-    trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
+    namespace = ray.get_runtime_context().namespace
+    explorer = (
+        ray.remote(Explorer)
+        .options(
+            name=EXPLORER_NAME,
+            namespace=namespace,
+        )
+        .remote(config)
+    )
+    trainer = (
+        ray.remote(Trainer)
+        .options(
+            name=TRAINER_NAME,
+            namespace=namespace,
+        )
+        .remote(config)
+    )
     ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
     ray.get(
         [
@@ -191,17 +227,16 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
         activate_data_module(
             f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
         )
-    ray_namespace = config.ray_namespace
     if dlc:
         from trinity.utils.dlc_utils import setup_ray_cluster
 
-        setup_ray_cluster(namespace=ray_namespace)
+        setup_ray_cluster(namespace=config.ray_namespace)
     else:
         from trinity.utils.dlc_utils import is_running
 
         if not is_running:
             raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
-        ray.init(namespace=ray_namespace, ignore_reinit_error=True)
+        ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
     if config.mode == "explore":
         explore(config)
     elif config.mode == "train":
@@ -214,7 +249,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
     if dlc:
         from trinity.utils.dlc_utils import stop_ray_cluster
 
-        stop_ray_cluster()
+        stop_ray_cluster(namespace=config.ray_namespace)
 
 
 def studio(port: int = 8501):
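The launcher is the process that calls `ray.init(namespace=config.ray_namespace)`, so `ray.get_runtime_context().namespace` already returns the configured value when the Explorer and Trainer actors are created. A sketch of the end-to-end effect, with illustrative names (the `Explorer` stub, the actor name, and the namespace string are not from the patch): any process connected to the same cluster can look the actor up, as long as it supplies the same namespace.

import ray


class Explorer:
    """Illustrative stand-in for trinity's Explorer actor."""

    def ping(self):
        return "pong"


ray.init(namespace="my-project-my-run")
assert ray.get_runtime_context().namespace == "my-project-my-run"

explorer = (
    ray.remote(Explorer)
    .options(
        name="explorer",
        namespace=ray.get_runtime_context().namespace,
    )
    .remote()
)

# A named actor is only visible under its namespace; looking it up with a
# different namespace value raises ValueError.
handle = ray.get_actor("explorer", namespace="my-project-my-run")
assert ray.get(handle.ping.remote()) == "pong"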
""" - explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config) - trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config) + namespace = ray.get_runtime_context().namespace + explorer = ( + ray.remote(Explorer) + .options( + name=EXPLORER_NAME, + namespace=namespace, + ) + .remote(config) + ) + trainer = ( + ray.remote(Trainer) + .options( + name=TRAINER_NAME, + namespace=namespace, + ) + .remote(config) + ) ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()]) ray.get( [ @@ -191,17 +227,16 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): activate_data_module( f"{data_processor_config.data_processor_url}/experience_pipeline", config_path ) - ray_namespace = config.ray_namespace if dlc: from trinity.utils.dlc_utils import setup_ray_cluster - setup_ray_cluster(namespace=ray_namespace) + setup_ray_cluster(namespace=config.ray_namespace) else: from trinity.utils.dlc_utils import is_running if not is_running: raise RuntimeError("Ray is not running, please start it by `ray start --head`.") - ray.init(namespace=ray_namespace, ignore_reinit_error=True) + ray.init(namespace=config.ray_namespace, ignore_reinit_error=True) if config.mode == "explore": explore(config) elif config.mode == "train": @@ -214,7 +249,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): if dlc: from trinity.utils.dlc_utils import stop_ray_cluster - stop_ray_cluster() + stop_ray_cluster(namespace=config.ray_namespace) def studio(port: int = 8501): diff --git a/trinity/common/config.py b/trinity/common/config.py index 04e90f00e9..bcce3bf217 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -181,7 +181,6 @@ class InferenceModelConfig: # ! DO NOT SET bundle_indices: str = "" - ray_namespace: str = "" @dataclass @@ -354,7 +353,7 @@ class Config: checkpoint_root_dir: str = "" # ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name` checkpoint_job_dir: str = "" - # ! 
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 3b3780b360..79c0cfae01 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -8,6 +8,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import aiohttp
+import ray
 import torch
 import vllm
 from vllm.sampling_params import RequestOutputKind
@@ -298,7 +299,7 @@ async def init_process_group(
             timeout,
             update_with_checkpoint,
             state_dict_meta,
-            self.config.ray_namespace,
+            ray.get_runtime_context().namespace,
         ),
     )
 
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 9ade92ab1b..3efe88b000 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -10,6 +10,7 @@
 import threading
 from typing import List, Optional, Tuple
 
+import ray
 import torch
 import vllm
 from vllm import LLM
@@ -112,7 +113,7 @@ def init_process_group(
             timeout,
             update_with_checkpoint,
             state_dict_meta,
-            self.config.ray_namespace,
+            ray.get_runtime_context().namespace,
         ),
     )
 
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 674027b690..2a156b8a2a 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -23,7 +23,7 @@ def init_process_group(
         timeout: int = 1200,
        update_with_checkpoint: bool = True,
         state_dict_meta: list = None,
-        namespace: str = "",
+        namespace: str = None,
     ):
         """Init torch process group for model weights update"""
         assert torch.distributed.is_initialized(), "default torch process group must be initialized"
@@ -53,7 +53,7 @@ def init_process_group(
             group_name=group_name,
         )
         logger.info("vLLM init_process_group finished.")
-        self.namespace = namespace
+        self._namespace = namespace
         self._explorer_actor = None
 
     def set_state_dict_meta(self, state_dict_meta):
@@ -63,7 +63,7 @@ def update_weight(self):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
         assert self._state_dict_meta is not None
         if self._explorer_actor is None:
-            self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self.namespace)
+            self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self._namespace)
         for name, dtype_str, shape in self._state_dict_meta:
             if self._weight_update_rank == 0:
                 weight = ray.get(self._explorer_actor.get_weight.remote(name))
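The `namespace` default in `init_process_group` changes from `""` to `None` to match `ray.get_actor` semantics: `None` means "look in the caller's current namespace", whereas an empty string is rejected by Ray as an invalid namespace. A small sketch of the lookup the worker performs, with an illustrative actor name and stub class (not from the patch):

import ray


class Explorer:
    """Hypothetical stand-in for the explorer actor the worker looks up."""

    def get_weight(self, name):
        return name


ray.init(namespace="my-project-my-run")
handle = ray.remote(Explorer).options(name="explorer").remote()

# namespace=None falls back to the caller's own namespace, so a worker
# running inside the job finds the explorer with no extra plumbing.
explorer = ray.get_actor("explorer", namespace=None)
assert ray.get(explorer.get_weight.remote("w")) == "w"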
@@ -68,8 +69,17 @@ def _create_actors(self, num: int = 1):
                 self.auxiliary_models, self.auxiliary_engine_status_list
             )
         ]
-        new_actor = WorkflowRunner.remote(
-            self.config, self.models[engine_index], selected_auxiliary_models
-        )
+        new_actor = (
+            ray.remote(WorkflowRunner)
+            .options(
+                namespace=self._namespace,
+                scheduling_strategy="SPREAD",
+            )
+            .remote(
+                self.config,
+                self.models[engine_index],
+                selected_auxiliary_models,
+            )
+        )
         new_actors.append(new_actor)
         self.engine_status[engine_index] += 1
@@ -219,7 +229,12 @@ def get_next(self) -> Status:
                 ray.kill(a)
                 # TODO: balance the model
                 self._return_actor(
-                    WorkflowRunner.remote(
+                    ray.remote(WorkflowRunner)
+                    .options(
+                        namespace=self._namespace,
+                        scheduling_strategy="SPREAD",
+                    )
+                    .remote(
                         self.config,
                         self.models[
                             random.randint(0, self.config.explorer.rollout_model.engine_num - 1)
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 3b36423ff6..b63a9ffadf 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -7,8 +7,6 @@
 from dataclasses import dataclass
 from typing import List, Optional
 
-import ray
-
 from trinity.buffer import get_buffer_writer
 from trinity.common.config import Config
 from trinity.common.experience import Experience
@@ -26,7 +24,6 @@ class Status:
     message: Optional[str] = None
 
 
-@ray.remote(scheduling_strategy="SPREAD")
 class WorkflowRunner:
     """A Ray remote actor to run the workflow and put the returned experiences into the buffer."""
diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py
index 761d95965b..73c15ab4da 100644
--- a/trinity/explorer/runner_pool.py
+++ b/trinity/explorer/runner_pool.py
@@ -56,6 +56,7 @@ def __init__(
         ]
         self._idle_actors = list()
         self.actor_to_engine_index = {}
+        self._namespace = ray.get_runtime_context().namespace
         self._create_actors(config.explorer.runner_num)
 
     def _create_actors(self, num: int = 1):
diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py
index b250d856d6..10f6f3b8bd 100644
--- a/trinity/utils/dlc_utils.py
+++ b/trinity/utils/dlc_utils.py
@@ -12,7 +12,6 @@
 CLUSTER_ACTOR_NAME = "cluster_status"
 
 
-@ray.remote
 class ClusterStatus:
     def __init__(self):
         self.finished = False
@@ -97,10 +96,15 @@ def setup_ray_cluster(namespace: str):
         wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
     else:
         # woker wait on the cluster status actor
-        cluster_status = ClusterStatus.options(
-            name=CLUSTER_ACTOR_NAME,
-            get_if_exists=True,
-        ).remote()
+        cluster_status = (
+            ray.remote(ClusterStatus)
+            .options(
+                name=CLUSTER_ACTOR_NAME,
+                namespace=namespace,
+                get_if_exists=True,
+            )
+            .remote()
+        )
         while True:
             if ray.get(cluster_status.running.remote()):
                 ret = subprocess.run("ray status", shell=True, capture_output=True)
@@ -112,13 +116,18 @@
             sys.exit(0)
 
 
-def stop_ray_cluster():
+def stop_ray_cluster(namespace: str):
     """
     Stop the ray cluster by sending a signal to the cluster status actor.
     """
-    cluster_status = ClusterStatus.options(
-        name=CLUSTER_ACTOR_NAME,
-        get_if_exists=True,
-    ).remote()
+    cluster_status = (
+        ray.remote(ClusterStatus)
+        .options(
+            name=CLUSTER_ACTOR_NAME,
+            namespace=namespace,
+            get_if_exists=True,
+        )
+        .remote()
+    )
     ray.get(cluster_status.finish.remote())
     logger.info("Stopping ray cluster...")
diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py
index a5a779ae83..3a4e935d7d 100644
--- a/trinity/utils/plugin_loader.py
+++ b/trinity/utils/plugin_loader.py
@@ -60,6 +60,9 @@ def load_from_file(file_path: str):
     if full_module_name in sys.modules:
         raise ImportError(f"Module {module_name} already exists.")
     sys.modules[full_module_name] = module
-    shutil.copy2(file_path, Path(__file__).parent.parent / "plugins")
+    try:
+        shutil.copy2(file_path, Path(__file__).parent.parent / "plugins")
+    except shutil.SameFileError:
+        pass
     logger.info(f"Load {file_path} as {full_module_name}")
     return module
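The `plugin_loader` fix in miniature: `shutil.copy2` raises `shutil.SameFileError` when source and destination are the same file, which happens when a plugin is loaded from the built-in `plugins` directory itself. The new guard turns that case into a no-op. A self-contained sketch with an illustrative path:

import shutil
from pathlib import Path

src = Path("demo_plugin.py")  # illustrative plugin file
src.write_text("# demo plugin\n")

try:
    shutil.copy2(src, src)  # destination resolves to the source itself
except shutil.SameFileError:
    pass  # plugin already lives in the target directory; nothing to do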