Skip to content

Commit 8e4ce25

Browse files
authored
Unify ray actor creation method (#97)
1 parent b8bd0ba commit 8e4ce25

File tree

16 files changed

+120
-51
lines changed

16 files changed

+120
-51
lines changed

tests/tools.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def metric_list(self, metric_prefix: str) -> List[str]:
158158
class RayUnittestBase(unittest.TestCase):
159159
@classmethod
160160
def setUpClass(cls):
161-
ray.init(ignore_reinit_error=True)
161+
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
162162

163163
@classmethod
164164
def tearDownClass(cls):

tests/trainer/trainer_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@ def setUp(self):
2727
self.config.model.model_path = get_model_path()
2828
self.config.explorer.rollout_model.engine_type = "vllm_async"
2929
self.config.algorithm.repeat_times = 3
30-
self.config.explorer.rollout_model.use_v1 = False
3130
self.config.project = "Trainer-unittest"
3231
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
3332
self.config.monitor.monitor_type = "tensorboard"
@@ -45,6 +44,7 @@ class TestTrainerCountdown(BaseTrainerCase):
4544
def test_trainer(self):
4645
"""Test the both and bench mode."""
4746
# test both mode
47+
self.config.explorer.rollout_model.use_v1 = False
4848
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
4949
self.config.buffer.explorer_input.eval_tasksets.append(
5050
get_unittest_dataset_config("countdown", "test")

trinity/buffer/queue.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@
33
from copy import deepcopy
44
from typing import List
55

6-
import ray
7-
86
from trinity.buffer.writer.file_writer import JSONWriter
97
from trinity.buffer.writer.sql_writer import SQLWriter
108
from trinity.common.config import BufferConfig, StorageConfig
@@ -20,7 +18,6 @@ def is_json_file(path: str) -> bool:
2018
return path.endswith(".json") or path.endswith(".jsonl")
2119

2220

23-
@ray.remote
2421
class QueueActor:
2522
"""An asyncio.Queue based queue actor."""
2623

trinity/buffer/ray_wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
5555
ray.remote(cls)
5656
.options(
5757
name=f"sql-{storage_config.name}",
58+
namespace=ray.get_runtime_context().namespace,
5859
get_if_exists=True,
5960
)
6061
.remote(storage_config, config)
@@ -154,6 +155,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
154155
ray.remote(cls)
155156
.options(
156157
name=f"json-{storage_config.name}",
158+
namespace=ray.get_runtime_context().namespace,
157159
get_if_exists=True,
158160
)
159161
.remote(storage_config, config)

trinity/buffer/reader/queue_reader.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,15 @@ class QueueReader(BufferReader):
1919
def __init__(self, storage_config: StorageConfig, config: BufferConfig):
2020
assert storage_config.storage_type == StorageType.QUEUE
2121
self.read_batch_size = config.read_batch_size
22-
self.queue = QueueActor.options(
23-
name=f"queue-{storage_config.name}",
24-
get_if_exists=True,
25-
).remote(storage_config, config)
22+
self.queue = (
23+
ray.remote(QueueActor)
24+
.options(
25+
name=f"queue-{storage_config.name}",
26+
namespace=ray.get_runtime_context().namespace,
27+
get_if_exists=True,
28+
)
29+
.remote(storage_config, config)
30+
)
2631

2732
def read(
2833
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None

trinity/buffer/writer/queue_writer.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,15 @@ class QueueWriter(BufferWriter):
1818
def __init__(self, meta: StorageConfig, config: BufferConfig):
1919
assert meta.storage_type == StorageType.QUEUE
2020
self.config = config
21-
self.queue = QueueActor.options(
22-
name=f"queue-{meta.name}",
23-
get_if_exists=True,
24-
).remote(meta, config)
21+
self.queue = (
22+
ray.remote(QueueActor)
23+
.options(
24+
name=f"queue-{meta.name}",
25+
namespace=ray.get_runtime_context().namespace,
26+
get_if_exists=True,
27+
)
28+
.remote(meta, config)
29+
)
2530

2631
def write(self, data: List) -> None:
2732
ray.get(self.queue.put_batch.remote(data))

trinity/cli/launcher.py

Lines changed: 44 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,14 @@
2020

2121
def bench(config: Config) -> None:
2222
"""Evaluate model."""
23-
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
23+
explorer = (
24+
ray.remote(Explorer)
25+
.options(
26+
name=EXPLORER_NAME,
27+
namespace=ray.get_runtime_context().namespace,
28+
)
29+
.remote(config)
30+
)
2431
try:
2532
ray.get(explorer.prepare.remote())
2633
ray.get(explorer.benchmark.remote())
@@ -34,7 +41,14 @@ def bench(config: Config) -> None:
3441
def explore(config: Config) -> None:
3542
"""Run explorer."""
3643
try:
37-
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
44+
explorer = (
45+
ray.remote(Explorer)
46+
.options(
47+
name=EXPLORER_NAME,
48+
namespace=ray.get_runtime_context().namespace,
49+
)
50+
.remote(config)
51+
)
3852
ray.get(explorer.prepare.remote())
3953
ray.get(explorer.sync_weight.remote())
4054
ray.get(explorer.explore.remote())
@@ -47,7 +61,14 @@ def explore(config: Config) -> None:
4761
def train(config: Config) -> None:
4862
"""Run trainer."""
4963
try:
50-
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
64+
trainer = (
65+
ray.remote(Trainer)
66+
.options(
67+
name=TRAINER_NAME,
68+
namespace=ray.get_runtime_context().namespace,
69+
)
70+
.remote(config)
71+
)
5172
ray.get(trainer.prepare.remote())
5273
ray.get(trainer.sync_weight.remote())
5374
ray.get(trainer.train.remote())
@@ -67,8 +88,23 @@ def both(config: Config) -> None:
6788
the latest step. The specific number of experiences may vary for different
6889
algorithms and tasks.
6990
"""
70-
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
71-
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
91+
namespace = ray.get_runtime_context().namespace
92+
explorer = (
93+
ray.remote(Explorer)
94+
.options(
95+
name=EXPLORER_NAME,
96+
namespace=namespace,
97+
)
98+
.remote(config)
99+
)
100+
trainer = (
101+
ray.remote(Trainer)
102+
.options(
103+
name=TRAINER_NAME,
104+
namespace=namespace,
105+
)
106+
.remote(config)
107+
)
72108
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
73109
ray.get(
74110
[
@@ -191,17 +227,16 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
191227
activate_data_module(
192228
f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
193229
)
194-
ray_namespace = config.ray_namespace
195230
if dlc:
196231
from trinity.utils.dlc_utils import setup_ray_cluster
197232

198-
setup_ray_cluster(namespace=ray_namespace)
233+
setup_ray_cluster(namespace=config.ray_namespace)
199234
else:
200235
from trinity.utils.dlc_utils import is_running
201236

202237
if not is_running:
203238
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
204-
ray.init(namespace=ray_namespace, ignore_reinit_error=True)
239+
ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
205240
if config.mode == "explore":
206241
explore(config)
207242
elif config.mode == "train":
@@ -214,7 +249,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
214249
if dlc:
215250
from trinity.utils.dlc_utils import stop_ray_cluster
216251

217-
stop_ray_cluster()
252+
stop_ray_cluster(namespace=config.ray_namespace)
218253

219254

220255
def studio(port: int = 8501):

trinity/common/config.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -181,7 +181,6 @@ class InferenceModelConfig:
181181

182182
# ! DO NOT SET
183183
bundle_indices: str = ""
184-
ray_namespace: str = ""
185184

186185

187186
@dataclass
@@ -354,7 +353,7 @@ class Config:
354353
checkpoint_root_dir: str = ""
355354
# ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
356355
checkpoint_job_dir: str = ""
357-
# ! DO NOT SET, automatically generated as f"{config.project}-{config.name}"
356+
# If not set, automatically generated as f"{config.project}-{config.name}"
358357
ray_namespace: str = ""
359358

360359
algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
@@ -579,7 +578,8 @@ def check_and_update(self) -> None: # noqa: C901
579578
self._check_deprecated()
580579

581580
# set namespace
582-
self.ray_namespace = f"{self.project}-{self.name}"
581+
if self.ray_namespace is None or len(self.ray_namespace) == 0:
582+
self.ray_namespace = f"{self.project}-{self.name}"
583583

584584
# check algorithm
585585
self._check_algorithm()
@@ -611,9 +611,6 @@ def check_and_update(self) -> None: # noqa: C901
611611
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
612612
if self.explorer.rollout_model.max_response_tokens is None:
613613
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
614-
self.explorer.rollout_model.ray_namespace = self.ray_namespace
615-
for model in self.explorer.auxiliary_models:
616-
model.ray_namespace = self.ray_namespace
617614

618615
# check synchronizer
619616
self.synchronizer.explorer_world_size = (

trinity/common/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def create_inference_models(
8989
for bundle_id, node_id in bundle_node_map.items():
9090
node_bundle_map[node_id].append(bundle_id)
9191
allocator = _BundleAllocator(node_bundle_map)
92-
92+
namespace = ray.get_runtime_context().namespace
9393
# create rollout models
9494
for _ in range(config.explorer.rollout_model.engine_num):
9595
bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size)
@@ -101,6 +101,7 @@ def create_inference_models(
101101
.options(
102102
num_cpus=0,
103103
num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1,
104+
namespace=namespace,
104105
scheduling_strategy=PlacementGroupSchedulingStrategy(
105106
placement_group=pg,
106107
placement_group_capture_child_tasks=True,
@@ -128,6 +129,7 @@ def create_inference_models(
128129
.options(
129130
num_cpus=0,
130131
num_gpus=0 if model_config.tensor_parallel_size > 1 else 1,
132+
namespace=namespace,
131133
scheduling_strategy=PlacementGroupSchedulingStrategy(
132134
placement_group=pg,
133135
placement_group_capture_child_tasks=True,

trinity/common/models/vllm_async_model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Dict, List, Optional, Tuple, Union
99

1010
import aiohttp
11+
import ray
1112
import torch
1213
import vllm
1314
from vllm.sampling_params import RequestOutputKind
@@ -298,7 +299,7 @@ async def init_process_group(
298299
timeout,
299300
update_with_checkpoint,
300301
state_dict_meta,
301-
self.config.ray_namespace,
302+
ray.get_runtime_context().namespace,
302303
),
303304
)
304305

0 commit comments

Comments
 (0)