Skip to content

Commit 061407c

Browse files
committed
Merge branch 'main' into feat/exp_pipeline
2 parents 1430035 + e4052f8 commit 061407c

File tree

17 files changed

+145
-63
lines changed

17 files changed

+145
-63
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,13 @@ Used to log training metrics during execution.
107107
```yaml
108108
monitor:
109109
monitor_type: wandb
110+
enable_ray_timeline: False
110111
```
111112

112113
- `monitor_type`: Type of monitoring system. Options:
113114
- `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
114115
- `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
116+
- `enable_ray_timeline`: Whether to export the ray timeline. If set to `True`, a `timeline.json` file will be exported to `<checkpoint_root_dir>/<project>/<name>/monitor`. You can view the timeline file in Chrome at [chrome://tracing](chrome://tracing).
115117

116118
---
117119

tests/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def metric_list(self, metric_prefix: str) -> List[str]:
158158
class RayUnittestBase(unittest.TestCase):
159159
@classmethod
160160
def setUpClass(cls):
161-
ray.init(ignore_reinit_error=True)
161+
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
162162

163163
@classmethod
164164
def tearDownClass(cls):

tests/trainer/trainer_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def setUp(self):
2727
self.config.model.model_path = get_model_path()
2828
self.config.explorer.rollout_model.engine_type = "vllm_async"
2929
self.config.algorithm.repeat_times = 3
30-
self.config.explorer.rollout_model.use_v1 = False
3130
self.config.project = "Trainer-unittest"
3231
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
3332
self.config.monitor.monitor_type = "tensorboard"
@@ -45,6 +44,7 @@ class TestTrainerCountdown(BaseTrainerCase):
4544
def test_trainer(self):
4645
"""Test the both and bench mode."""
4746
# test both mode
47+
self.config.explorer.rollout_model.use_v1 = False
4848
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
4949
self.config.buffer.explorer_input.eval_tasksets.append(
5050
get_unittest_dataset_config("countdown", "test")
@@ -149,6 +149,7 @@ def test_trainer(self):
149149
response_metrics = parser.metric_list("response_length")
150150
self.assertTrue(len(response_metrics) > 0)
151151
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
152+
ray.timeline(filename="timeline.json")
152153
ray.shutdown(_exiting_interpreter=True)
153154
# check checkpoint
154155
from trinity.common.models.utils import get_checkpoint_dir_with_step_num

trinity/buffer/queue.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from copy import deepcopy
44
from typing import List
55

6-
import ray
7-
86
from trinity.buffer.writer.file_writer import JSONWriter
97
from trinity.buffer.writer.sql_writer import SQLWriter
108
from trinity.common.config import BufferConfig, StorageConfig
@@ -20,7 +18,6 @@ def is_json_file(path: str) -> bool:
2018
return path.endswith(".json") or path.endswith(".jsonl")
2119

2220

23-
@ray.remote
2421
class QueueActor:
2522
"""An asyncio.Queue based queue actor."""
2623

trinity/buffer/ray_wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
5555
ray.remote(cls)
5656
.options(
5757
name=f"sql-{storage_config.name}",
58+
namespace=ray.get_runtime_context().namespace,
5859
get_if_exists=True,
5960
)
6061
.remote(storage_config, config)
@@ -154,6 +155,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
154155
ray.remote(cls)
155156
.options(
156157
name=f"json-{storage_config.name}",
158+
namespace=ray.get_runtime_context().namespace,
157159
get_if_exists=True,
158160
)
159161
.remote(storage_config, config)

trinity/buffer/reader/queue_reader.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ class QueueReader(BufferReader):
1919
def __init__(self, storage_config: StorageConfig, config: BufferConfig):
2020
assert storage_config.storage_type == StorageType.QUEUE
2121
self.read_batch_size = config.read_batch_size
22-
self.queue = QueueActor.options(
23-
name=f"queue-{storage_config.name}",
24-
get_if_exists=True,
25-
).remote(storage_config, config)
22+
self.queue = (
23+
ray.remote(QueueActor)
24+
.options(
25+
name=f"queue-{storage_config.name}",
26+
namespace=ray.get_runtime_context().namespace,
27+
get_if_exists=True,
28+
)
29+
.remote(storage_config, config)
30+
)
2631

2732
def read(
2833
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None

trinity/buffer/writer/queue_writer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ class QueueWriter(BufferWriter):
1818
def __init__(self, meta: StorageConfig, config: BufferConfig):
1919
assert meta.storage_type == StorageType.QUEUE
2020
self.config = config
21-
self.queue = QueueActor.options(
22-
name=f"queue-{meta.name}",
23-
get_if_exists=True,
24-
).remote(meta, config)
21+
self.queue = (
22+
ray.remote(QueueActor)
23+
.options(
24+
name=f"queue-{meta.name}",
25+
namespace=ray.get_runtime_context().namespace,
26+
get_if_exists=True,
27+
)
28+
.remote(meta, config)
29+
)
2530

2631
def write(self, data: List) -> None:
2732
ray.get(self.queue.put_batch.remote(data))

trinity/cli/launcher.py

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@
2020

2121
def bench(config: Config) -> None:
2222
"""Evaluate model."""
23-
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
23+
explorer = (
24+
ray.remote(Explorer)
25+
.options(
26+
name=EXPLORER_NAME,
27+
namespace=ray.get_runtime_context().namespace,
28+
)
29+
.remote(config)
30+
)
2431
try:
2532
ray.get(explorer.prepare.remote())
2633
ray.get(explorer.benchmark.remote())
@@ -34,7 +41,14 @@ def bench(config: Config) -> None:
3441
def explore(config: Config) -> None:
3542
"""Run explorer."""
3643
try:
37-
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
44+
explorer = (
45+
ray.remote(Explorer)
46+
.options(
47+
name=EXPLORER_NAME,
48+
namespace=ray.get_runtime_context().namespace,
49+
)
50+
.remote(config)
51+
)
3852
ray.get(explorer.prepare.remote())
3953
ray.get(explorer.sync_weight.remote())
4054
ray.get(explorer.explore.remote())
@@ -47,7 +61,14 @@ def explore(config: Config) -> None:
4761
def train(config: Config) -> None:
4862
"""Run trainer."""
4963
try:
50-
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
64+
trainer = (
65+
ray.remote(Trainer)
66+
.options(
67+
name=TRAINER_NAME,
68+
namespace=ray.get_runtime_context().namespace,
69+
)
70+
.remote(config)
71+
)
5172
ray.get(trainer.prepare.remote())
5273
ray.get(trainer.sync_weight.remote())
5374
ray.get(trainer.train.remote())
@@ -67,8 +88,23 @@ def both(config: Config) -> None:
6788
the latest step. The specific number of experiences may vary for different
6889
algorithms and tasks.
6990
"""
70-
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
71-
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
91+
namespace = ray.get_runtime_context().namespace
92+
explorer = (
93+
ray.remote(Explorer)
94+
.options(
95+
name=EXPLORER_NAME,
96+
namespace=namespace,
97+
)
98+
.remote(config)
99+
)
100+
trainer = (
101+
ray.remote(Trainer)
102+
.options(
103+
name=TRAINER_NAME,
104+
namespace=namespace,
105+
)
106+
.remote(config)
107+
)
72108
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
73109
ray.get(
74110
[
@@ -192,30 +228,36 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
192228
activate_data_module(
193229
f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}", config_path
194230
)
195-
ray_namespace = config.ray_namespace
196231
if dlc:
197232
from trinity.utils.dlc_utils import setup_ray_cluster
198233

199-
setup_ray_cluster(namespace=ray_namespace)
234+
setup_ray_cluster(namespace=config.ray_namespace)
200235
else:
201236
from trinity.utils.dlc_utils import is_running
202237

203238
if not is_running:
204239
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
205-
ray.init(namespace=ray_namespace, ignore_reinit_error=True)
206-
if config.mode == "explore":
207-
explore(config)
208-
elif config.mode == "train":
209-
train(config)
210-
elif config.mode == "both":
211-
both(config)
212-
elif config.mode == "bench":
213-
bench(config)
214-
215-
if dlc:
216-
from trinity.utils.dlc_utils import stop_ray_cluster
217-
218-
stop_ray_cluster()
240+
ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
241+
try:
242+
if config.mode == "explore":
243+
explore(config)
244+
elif config.mode == "train":
245+
train(config)
246+
elif config.mode == "both":
247+
both(config)
248+
elif config.mode == "bench":
249+
bench(config)
250+
finally:
251+
if config.monitor.enable_ray_timeline:
252+
timeline_file = os.path.join(config.monitor.cache_dir, "timeline.json")
253+
logger.info(f"Exporting Ray timeline to {timeline_file}...")
254+
ray.timeline(filename=timeline_file)
255+
logger.info("Done. You can open the timeline file in `chrome://tracing`")
256+
257+
if dlc:
258+
from trinity.utils.dlc_utils import stop_ray_cluster
259+
260+
stop_ray_cluster(namespace=config.ray_namespace)
219261

220262

221263
def studio(port: int = 8501):

trinity/common/config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ class InferenceModelConfig:
192192

193193
# ! DO NOT SET
194194
bundle_indices: str = ""
195-
ray_namespace: str = ""
196195

197196

198197
@dataclass
@@ -331,6 +330,9 @@ class MonitorConfig:
331330
monitor_type: str = "tensorboard"
332331
# the default args for monitor
333332
monitor_args: Dict = field(default_factory=dict)
333+
# whether to enable ray timeline profile
334+
# the output file will be saved to `cache_dir/timeline.json`
335+
enable_ray_timeline: bool = False
334336
# ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
335337
cache_dir: str = ""
336338

@@ -365,7 +367,7 @@ class Config:
365367
checkpoint_root_dir: str = ""
366368
# ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
367369
checkpoint_job_dir: str = ""
368-
# ! DO NOT SET, automatically generated as f"{config.project}-{config.name}"
370+
# If not set, automatically generated as f"{config.project}-{config.name}"
369371
ray_namespace: str = ""
370372

371373
algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
@@ -590,7 +592,8 @@ def check_and_update(self) -> None: # noqa: C901
590592
self._check_deprecated()
591593

592594
# set namespace
593-
self.ray_namespace = f"{self.project}-{self.name}"
595+
if self.ray_namespace is None or len(self.ray_namespace) == 0:
596+
self.ray_namespace = f"{self.project}-{self.name}"
594597

595598
# check algorithm
596599
self._check_algorithm()
@@ -622,9 +625,6 @@ def check_and_update(self) -> None: # noqa: C901
622625
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
623626
if self.explorer.rollout_model.max_response_tokens is None:
624627
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
625-
self.explorer.rollout_model.ray_namespace = self.ray_namespace
626-
for model in self.explorer.auxiliary_models:
627-
model.ray_namespace = self.ray_namespace
628628

629629
# check synchronizer
630630
self.synchronizer.explorer_world_size = (

trinity/common/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def create_inference_models(
8989
for bundle_id, node_id in bundle_node_map.items():
9090
node_bundle_map[node_id].append(bundle_id)
9191
allocator = _BundleAllocator(node_bundle_map)
92-
92+
namespace = ray.get_runtime_context().namespace
9393
# create rollout models
9494
for _ in range(config.explorer.rollout_model.engine_num):
9595
bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size)
@@ -101,6 +101,7 @@ def create_inference_models(
101101
.options(
102102
num_cpus=0,
103103
num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1,
104+
namespace=namespace,
104105
scheduling_strategy=PlacementGroupSchedulingStrategy(
105106
placement_group=pg,
106107
placement_group_capture_child_tasks=True,
@@ -128,6 +129,7 @@ def create_inference_models(
128129
.options(
129130
num_cpus=0,
130131
num_gpus=0 if model_config.tensor_parallel_size > 1 else 1,
132+
namespace=namespace,
131133
scheduling_strategy=PlacementGroupSchedulingStrategy(
132134
placement_group=pg,
133135
placement_group_capture_child_tasks=True,

0 commit comments

Comments
 (0)