Skip to content

Commit 5bf9e42

Browse files
authored
Bug Fix in namespace (#95)
1 parent acf7788 commit 5bf9e42

File tree

5 files changed

+15
-2
lines changed

5 files changed

+15
-2
lines changed

trinity/cli/launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
191191
activate_data_module(
192192
f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
193193
)
194-
ray_namespace = f"{config.project}-{config.name}"
194+
ray_namespace = config.ray_namespace
195195
if dlc:
196196
from trinity.utils.dlc_utils import setup_ray_cluster
197197

trinity/common/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class InferenceModelConfig:
181181

182182
# ! DO NOT SET
183183
bundle_indices: str = ""
184+
ray_namespace: str = ""
184185

185186

186187
@dataclass
@@ -353,6 +354,8 @@ class Config:
353354
checkpoint_root_dir: str = ""
354355
# ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
355356
checkpoint_job_dir: str = ""
357+
# ! DO NOT SET, automatically generated as f"{config.project}-{config.name}"
358+
ray_namespace: str = ""
356359

357360
algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
358361
data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
@@ -575,6 +578,9 @@ def check_and_update(self) -> None: # noqa: C901
575578
"""Check and update the config."""
576579
self._check_deprecated()
577580

581+
# set namespace
582+
self.ray_namespace = f"{self.project}-{self.name}"
583+
578584
# check algorithm
579585
self._check_algorithm()
580586

@@ -605,6 +611,9 @@ def check_and_update(self) -> None: # noqa: C901
605611
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
606612
if self.explorer.rollout_model.max_response_tokens is None:
607613
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
614+
self.explorer.rollout_model.ray_namespace = self.ray_namespace
615+
for model in self.explorer.auxiliary_models:
616+
model.ray_namespace = self.ray_namespace
608617

609618
# check synchronizer
610619
self.synchronizer.explorer_world_size = (

trinity/common/models/vllm_async_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ async def init_process_group(
298298
timeout,
299299
update_with_checkpoint,
300300
state_dict_meta,
301+
self.config.ray_namespace,
301302
),
302303
)
303304

trinity/common/models/vllm_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def init_process_group(
112112
timeout,
113113
update_with_checkpoint,
114114
state_dict_meta,
115+
self.config.ray_namespace,
115116
),
116117
)
117118

trinity/common/models/vllm_worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def init_process_group(
2323
timeout: int = 1200,
2424
update_with_checkpoint: bool = True,
2525
state_dict_meta: list = None,
26+
namespace: str = "",
2627
):
2728
"""Init torch process group for model weights update"""
2829
assert torch.distributed.is_initialized(), "default torch process group must be initialized"
@@ -52,6 +53,7 @@ def init_process_group(
5253
group_name=group_name,
5354
)
5455
logger.info("vLLM init_process_group finished.")
56+
self.namespace = namespace
5557
self._explorer_actor = None
5658

5759
def set_state_dict_meta(self, state_dict_meta):
@@ -61,7 +63,7 @@ def update_weight(self):
6163
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
6264
assert self._state_dict_meta is not None
6365
if self._explorer_actor is None:
64-
self._explorer_actor = ray.get_actor(name=EXPLORER_NAME)
66+
self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self.namespace)
6567
for name, dtype_str, shape in self._state_dict_meta:
6668
if self._weight_update_rank == 0:
6769
weight = ray.get(self._explorer_actor.get_weight.remote(name))

0 commit comments

Comments
 (0)