Skip to content

Commit e4052f8

Browse files
authored
Add ray timeline for profiling (#98)
1 parent 8e4ce25 commit e4052f8

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,13 @@ Used to log training metrics during execution.
107107
```yaml
108108
monitor:
109109
monitor_type: wandb
110+
enable_ray_timeline: False
110111
```
111112

112113
- `monitor_type`: Type of monitoring system. Options:
113114
- `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
114115
- `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
116+
- `enable_ray_timeline`: Whether to export the ray timeline. If set to `True`, a `timeline.json` file will be exported to `<checkpoint_root_dir>/<project>/<name>/monitor`. You can view the timeline file in Chrome at [chrome://tracing](chrome://tracing).
115117

116118
---
117119

tests/trainer/trainer_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def test_trainer(self):
149149
response_metrics = parser.metric_list("response_length")
150150
self.assertTrue(len(response_metrics) > 0)
151151
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
152+
ray.timeline(filename="timeline.json")
152153
ray.shutdown(_exiting_interpreter=True)
153154
# check checkpoint
154155
from trinity.common.models.utils import get_checkpoint_dir_with_step_num

trinity/cli/launcher.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -237,19 +237,26 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
237237
if not is_running:
238238
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
239239
ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
240-
if config.mode == "explore":
241-
explore(config)
242-
elif config.mode == "train":
243-
train(config)
244-
elif config.mode == "both":
245-
both(config)
246-
elif config.mode == "bench":
247-
bench(config)
248-
249-
if dlc:
250-
from trinity.utils.dlc_utils import stop_ray_cluster
251-
252-
stop_ray_cluster(namespace=config.ray_namespace)
240+
try:
241+
if config.mode == "explore":
242+
explore(config)
243+
elif config.mode == "train":
244+
train(config)
245+
elif config.mode == "both":
246+
both(config)
247+
elif config.mode == "bench":
248+
bench(config)
249+
finally:
250+
if config.monitor.enable_ray_timeline:
251+
timeline_file = os.path.join(config.monitor.cache_dir, "timeline.json")
252+
logger.info(f"Exporting Ray timeline to {timeline_file}...")
253+
ray.timeline(filename=timeline_file)
254+
logger.info("Done. You can open the timeline file in `chrome://tracing`")
255+
256+
if dlc:
257+
from trinity.utils.dlc_utils import stop_ray_cluster
258+
259+
stop_ray_cluster(namespace=config.ray_namespace)
253260

254261

255262
def studio(port: int = 8501):

trinity/common/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ class MonitorConfig:
319319
monitor_type: str = "tensorboard"
320320
# the default args for monitor
321321
monitor_args: Dict = field(default_factory=dict)
322+
# whether to enable ray timeline profile
323+
# the output file will be saved to `cache_dir/timeline.json`
324+
enable_ray_timeline: bool = False
322325
# ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
323326
cache_dir: str = ""
324327

0 commit comments

Comments
 (0)