diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 3ea4f0486f..d4f171803e 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -191,7 +191,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
         activate_data_module(
             f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
         )
-    ray_namespace = f"{config.project}-{config.name}"
+    ray_namespace = config.ray_namespace
     if dlc:
         from trinity.utils.dlc_utils import setup_ray_cluster
diff --git a/trinity/common/config.py b/trinity/common/config.py
index f4480da311..04e90f00e9 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -181,6 +181,7 @@ class InferenceModelConfig:
     # ! DO NOT SET
     bundle_indices: str = ""
+    ray_namespace: str = ""
 
 
 @dataclass
@@ -353,6 +354,8 @@ class Config:
     checkpoint_root_dir: str = ""
     # ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
     checkpoint_job_dir: str = ""
+    # ! DO NOT SET, automatically generated as f"{config.project}-{config.name}"
+    ray_namespace: str = ""
 
     algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
     data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
@@ -575,6 +578,9 @@ def check_and_update(self) -> None:  # noqa: C901
         """Check and update the config."""
         self._check_deprecated()
 
+        # set namespace
+        self.ray_namespace = f"{self.project}-{self.name}"
+
         # check algorithm
         self._check_algorithm()
 
@@ -605,6 +611,9 @@ def check_and_update(self) -> None:  # noqa: C901
             self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
         if self.explorer.rollout_model.max_response_tokens is None:
             self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
+        self.explorer.rollout_model.ray_namespace = self.ray_namespace
+        for model in self.explorer.auxiliary_models:
+            model.ray_namespace = self.ray_namespace
 
         # check synchronizer
         self.synchronizer.explorer_world_size = (
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 8a8a089afa..3b3780b360 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -298,6 +298,7 @@ async def init_process_group(
                 timeout,
                 update_with_checkpoint,
                 state_dict_meta,
+                self.config.ray_namespace,
             ),
         )
 
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 878fe0bd9c..9ade92ab1b 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -112,6 +112,7 @@ def init_process_group(
                 timeout,
                 update_with_checkpoint,
                 state_dict_meta,
+                self.config.ray_namespace,
             ),
         )
 
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 883e470381..674027b690 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -23,6 +23,7 @@ def init_process_group(
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
         state_dict_meta: list = None,
+        namespace: str = "",
     ):
         """Init torch process group for model weights update"""
         assert torch.distributed.is_initialized(), "default torch process group must be initialized"
@@ -52,6 +53,7 @@
             group_name=group_name,
         )
         logger.info("vLLM init_process_group finished.")
+        self.namespace = namespace
         self._explorer_actor = None
 
     def set_state_dict_meta(self, state_dict_meta):
@@ -61,7 +63,7 @@ def update_weight(self):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
        assert self._state_dict_meta is not None
         if self._explorer_actor is None:
-            self._explorer_actor = ray.get_actor(name=EXPLORER_NAME)
+            self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self.namespace)
         for name, dtype_str, shape in self._state_dict_meta:
             if self._weight_update_rank == 0:
                 weight = ray.get(self._explorer_actor.get_weight.remote(name))
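
Reviewer note (not part of the patch): ray.get_actor(name=...) only searches the
caller's current Ray namespace, so a vLLM worker running outside the job's
namespace cannot resolve the detached explorer actor by name alone. This patch
therefore threads the generated namespace (f"{project}-{name}") from Config down
through init_process_group to the worker. Below is a minimal sketch of the
failure mode and the fix, assuming hypothetical actor and namespace names
("explorer", "my-project-my-run"); only the ray calls shown are real API.

    import ray

    # The launcher connects to Ray inside the job's namespace.
    ray.init(namespace="my-project-my-run")

    @ray.remote
    class Explorer:
        def get_weight(self, name):
            return name  # placeholder for a real weight lookup

    # The explorer registers itself as a named, detached actor.
    Explorer.options(name="explorer", lifetime="detached").remote()

    # From a process attached to a different (e.g. anonymous) namespace:
    #   ray.get_actor(name="explorer")  # raises ValueError: actor not found
    # Passing the namespace explicitly resolves the actor, which is what
    # update_weight() now does via self.namespace:
    actor = ray.get_actor(name="explorer", namespace="my-project-my-run")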