Skip to content

Commit 675ff5b

Browse files
authored
Add a script to plot multi-run experiment results (#122)
1 parent 6b109b0 commit 675ff5b

File tree

5 files changed

+348
-0
lines changed

5 files changed

+348
-0
lines changed
218 KB
Loading

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ dev = [
6666
"pytest>=8.0.0",
6767
"pytest-json-ctrf",
6868
"parameterized",
69+
"matplotlib"
6970
]
7071

7172
doc = [

scripts/multi_exps_plot/README.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Multi-Run Experiment Plotting Script
2+
3+
## Description
4+
5+
Due to the stochastic nature of RFT results, multiple experimental runs are necessary to establish reliable average performance and confidence intervals. This script is designed to parse and plot the results from these repeated runs, enabling visual comparisons between different sets of experiments.
6+
7+
## Usage
8+
9+
***Before running this script***, ensure your experiment results are available. For example, after running the [grpo_gsm8k](https://github.com/modelscope/Trinity-RFT/blob/main/examples/grpo_gsm8k/gsm8k.yaml) script **three times**, the result directories will be located under a path pattern such as `/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-{1, 2, 3}`. The directory structure for a single run is expected to be as follows:
10+
11+
└── qwen2.5-1.5B-gsm8k-1
12+
├── buffer
13+
├── global_step_xxx
14+
└── monitor
15+
└── tensorboard
16+
├── explorer
17+
├── trainer
18+
└── ...
19+
20+
21+
***To run the script***, you need to configure the following key parameters in `plot_configs.yaml`:
22+
23+
```yaml
24+
plot_configs:
25+
# A list of all scalar keys to plot
26+
scalar_keys:
27+
- "eval/gsm8k-eval/accuracy/mean"
28+
- "response_length/mean"
29+
# - "critic/rewards/mean"
30+
31+
exps_configs:
32+
# Define each experiment group to be plotted
33+
gsm8k-train:
34+
# 'paths' should list the root directories of each individual run
35+
paths:
36+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-1"
37+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-2"
38+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-3"
39+
# - "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-n"
40+
41+
# Optional: Color of the curve.
42+
color: "blue"
43+
44+
# Define other experiment groups for comparison
45+
math-train:
46+
paths:
47+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-1"
48+
# ...
49+
color: "red"
50+
```
51+
52+
53+
Once the `YAML` file is configured, execute the following command to generate the plots (one image per scalar key):
54+
55+
```bash
56+
python scripts/multi_exps_plot/multi_exps_plot.py --config scripts/multi_exps_plot/plot_configs.yaml
57+
```
58+
59+
## Example
60+
61+
Below is an example of the output produced by this script. It compares RFT of `Qwen2.5-1.5B-Instruct` with `GRPO` on the `GSM8k` and `MATH` datasets, with performance evaluated on the `MATH500` benchmark.
62+
63+
![Example Plot of GRPO on GSM8k and MATH](../../docs/sphinx_doc/assets/scripts-multi-plot.png)
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
import argparse
2+
import glob
3+
import os
4+
import re
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
import yaml
10+
from tensorboard.backend.event_processing import event_accumulator
11+
12+
from trinity.utils.log import get_logger
13+
14+
# Initialize logger
15+
# Module-level logger; the functions in this script log through it.
logger = get_logger(__name__)
16+
17+
18+
def parse_args():
    """Parse the command-line arguments of the plotting script.

    Returns:
        The ``argparse.Namespace`` obtained from the process arguments. Its
        ``config`` attribute holds the value of the required ``--config``
        option (path to the YAML configuration file).
    """
    cli = argparse.ArgumentParser(description="Plot multi results from TensorBoard logs.")
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to the YAML configuration file.",
    )
    namespace = cli.parse_args()
    return namespace
24+
25+
26+
def load_config(config_path: str) -> dict:
    """Load the YAML configuration file that drives the plotting script.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration mapping. An empty document (for which
        ``yaml.safe_load`` returns ``None``) yields an empty dict, so callers
        can always call ``.get`` on the result.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file is not valid YAML (the original
            ``yaml.YAMLError`` is chained as the cause) or if its top-level
            value is not a mapping.
    """
    if not os.path.exists(config_path):
        logger.error(f"Configuration file not found: {config_path}")
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            logger.error(f"Error parsing YAML file: {e}", exc_info=True)
            # Chain the parser error so the traceback keeps the root cause.
            raise ValueError(f"Error parsing YAML file: {e}") from e

    if config is None:
        # Empty file: honor the annotated return type instead of leaking None.
        return {}
    if not isinstance(config, dict):
        message = (
            f"Top-level YAML content must be a mapping, got {type(config).__name__}: {config_path}"
        )
        logger.error(message)
        raise ValueError(message)
    return config
37+
38+
39+
def find_scalars_in_event_file(event_file: str) -> list[str]:
    """Return every scalar key recorded in a single tfevents file.

    Reading is best-effort: if the file cannot be loaded for any reason, a
    warning is logged and an empty list is returned instead of raising.

    Args:
        event_file: Path to the tfevents file to scan.

    Returns:
        The scalar keys reported by the event accumulator, or ``[]`` on
        failure.
    """
    try:
        accumulator = event_accumulator.EventAccumulator(
            event_file,
            size_guidance={event_accumulator.SCALARS: 0},
        )
        accumulator.Reload()
        scalar_keys = accumulator.scalars.Keys()
    except Exception as e:
        logger.warning(f"Could not read scalars from event file '{event_file}': {e}")
        return []
    return scalar_keys
50+
51+
52+
def build_scalar_location_map(base_path: str) -> dict[str, str]:
    """Map each scalar key of a run to the TensorBoard folder that holds it.

    Looks at ``<base_path>/monitor/tensorboard/explorer`` and
    ``<base_path>/monitor/tensorboard/trainer`` (in that order). Folders that
    do not exist or contain no ``events.out.tfevents.*`` file are skipped.

    Args:
        base_path: Root directory of a single experiment run.

    Returns:
        A dict from scalar key to ``"explorer"`` or ``"trainer"``. When a key
        appears in both folders, the first folder wins and a warning is
        logged. Empty when nothing could be read.
    """
    scalar_map: dict[str, str] = {}
    for folder in ("explorer", "trainer"):
        log_dir = os.path.join(base_path, "monitor", "tensorboard", folder)
        if not os.path.isdir(log_dir):
            continue

        # glob returns matches in arbitrary order. Sort and take the last one
        # so the inspected file is deterministic and is the same file that
        # find_tfevents_file selects when the scalar data is parsed later.
        event_files = sorted(glob.glob(os.path.join(log_dir, "events.out.tfevents.*")))
        if not event_files:
            continue

        for key in find_scalars_in_event_file(event_files[-1]):
            if key in scalar_map:
                logger.warning(
                    f"Duplicate scalar key '{key}' found. Using first one found ('{scalar_map[key]}')."
                )
            else:
                scalar_map[key] = folder
    return scalar_map
74+
75+
76+
def find_tfevents_file(dir_path: str) -> str | None:
77+
"""Finds a tfevents file within a specified directory"""
78+
event_files = glob.glob(os.path.join(dir_path, "events.out.tfevents.*"))
79+
if not event_files:
80+
return None
81+
if len(event_files) > 1:
82+
latest_file = sorted(event_files)[-1]
83+
logger.debug(
84+
f"Multiple tfevents files found in '{dir_path}'. Using the latest one: {latest_file}"
85+
)
86+
return latest_file
87+
return event_files[0]
88+
89+
90+
def parse_tensorboard_log(log_dir: str, scalar_key: str) -> pd.Series:
    """Extract one scalar's series from a TensorBoard log directory.

    Args:
        log_dir: Directory expected to contain ``events.out.tfevents.*`` files.
        scalar_key: Key of the scalar to extract.

    Returns:
        A Series of scalar values indexed by step and named ``log_dir``. If a
        step was logged more than once, only its last value is kept so the
        index is unique. An empty float64 Series is returned when no event
        file exists, the key is absent, or parsing fails; the problem is
        logged rather than raised.
    """
    try:
        event_file = find_tfevents_file(log_dir)
        if event_file is None:
            # Report directly instead of raising an exception only to catch
            # it a few lines below; the logged text is unchanged.
            logger.error(
                f"Failed to parse directory '{log_dir}': "
                f"No tfevents file found in directory: '{log_dir}'"
            )
            return pd.Series(dtype=np.float64)

        ea = event_accumulator.EventAccumulator(
            event_file, size_guidance={event_accumulator.SCALARS: 0}
        )
        ea.Reload()

        if scalar_key not in ea.scalars.Keys():
            logger.warning(f"Scalar key '{scalar_key}' not found in file '{event_file}'.")
            return pd.Series(dtype=np.float64)

        scalar_events = ea.scalars.Items(scalar_key)
        series = pd.Series(
            data=[e.value for e in scalar_events],
            index=[e.step for e in scalar_events],
            name=log_dir,
            dtype=np.float64,
        )
        # Runs are later aligned column-wise on the step index, which
        # requires unique labels; keep the most recent value per step.
        return series[~series.index.duplicated(keep="last")]

    except Exception as e:
        logger.error(f"Failed to parse directory '{log_dir}': {e}")
        return pd.Series(dtype=np.float64)
115+
116+
117+
def plot_confidence_interval(
    experiments_data: dict, title: str, x_label: str, y_label: str, output_filename: str
):
    """Plot the cross-run mean of each experiment with a shaded band and save it.

    Args:
        experiments_data: Maps an experiment name to a dict with a ``"data"``
            entry (list of per-run ``pd.Series`` indexed by step) and an
            optional ``"color"`` entry. Experiments whose ``"data"`` list is
            empty are skipped with a warning.
        title: Title of the chart.
        x_label: Label of the x axis.
        y_label: Label of the y axis.
        output_filename: Path of the image to write. Its parent directory is
            created when it does not exist.

    Note:
        The shaded band is ``mean ± std`` across runs at each step, not a
        statistical confidence interval.
    """
    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        has_curve = False

        for i, (exp_name, exp_details) in enumerate(experiments_data.items()):
            all_runs_data = exp_details["data"]
            # Fall back to the style's color cycle when no color is configured.
            color = exp_details.get("color") or color_cycle[i % len(color_cycle)]

            if not all_runs_data:
                logger.warning(f"No valid data for experiment '{exp_name}' on this plot. Skipping.")
                continue

            # One column per run, aligned on the step index.
            df = pd.concat(all_runs_data, axis=1)
            mean_values = df.mean(axis=1).sort_index()
            std_values = df.std(axis=1).sort_index()
            steps = mean_values.index.values

            ax.plot(
                steps,
                mean_values,
                label=exp_name,
                color=color,
                marker="o",
                markersize=4,
                linestyle="-",
            )
            ax.fill_between(
                steps, mean_values - std_values, mean_values + std_values, color=color, alpha=0.2
            )
            has_curve = True

        ax.set_title(title, fontsize=16, pad=20)
        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        if has_curve:
            # Without any labeled curve there is nothing to put in a legend.
            ax.legend(loc="best", fontsize=12)
        ax.tick_params(axis="both", which="major", labelsize=10)

        output_dir = os.path.dirname(output_filename)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Operate on this figure explicitly rather than pyplot's current one.
        fig.tight_layout()
        fig.savefig(output_filename, dpi=300)
        logger.info(f"Chart successfully saved to '{output_filename}'")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
159+
160+
161+
def main():
    """Read the YAML configuration and save one chart per requested scalar key.

    Steps:
        1. Parse ``--config`` and load the configuration.
        2. For every experiment group, build a map from scalar key to the
           TensorBoard folder holding it, using the first run directory that
           yields a non-empty map.
        3. For every scalar key, collect the series of all runs of every
           group and plot them to ``<output_path>/<sanitized key>.png``.

    Returns early with a warning when no ``scalar_keys`` are configured.
    """
    args = parse_args()
    config = load_config(args.config)
    logger.info(f"Successfully loaded configuration from: {args.config}")

    # Extract settings. `or` also covers an empty file and sections that are
    # present in the YAML but left empty (those parse to None).
    config = config or {}
    plot_cfg = config.get("plot_configs") or {}
    exps_cfg = config.get("exps_configs") or {}

    output_path = plot_cfg.get("output_path", "./plots")
    scalar_keys_to_plot = plot_cfg.get("scalar_keys") or []

    if not scalar_keys_to_plot:
        logger.warning("No 'scalar_keys' specified in 'plot_configs'.")
        return

    # Build scalar location maps for each experiment group
    scalar_maps = {}
    for exp_name, exp_details in exps_cfg.items():
        logger.info(f"Scanning for scalars in experiment group: {exp_name}")
        for path in (exp_details or {}).get("paths") or []:
            if os.path.isdir(path):
                scalar_maps[exp_name] = build_scalar_location_map(path)
                if scalar_maps[exp_name]:
                    logger.info(
                        f"Scalar map for '{exp_name}' created successfully from path: {path}"
                    )
                    break
        # Also warn when directories exist but none produced any scalar key;
        # checking only for a missing entry would hide that case.
        if not scalar_maps.get(exp_name):
            logger.warning(
                f"Could not create a scalar map for '{exp_name}'. All paths might be invalid."
            )
            scalar_maps[exp_name] = {}

    # Main Loop: Generate one plot for each specified scalar key
    for scalar_key in scalar_keys_to_plot:
        logger.info(f"\n--- Generating plot for scalar_key: '{scalar_key}' ---")
        experiments_data_for_this_plot = {}

        for exp_name, exp_details in exps_cfg.items():
            exp_details = exp_details or {}
            scalar_map = scalar_maps.get(exp_name, {})
            if scalar_key not in scalar_map:
                logger.warning(
                    f"Scalar '{scalar_key}' not found for experiment '{exp_name}'. Skipping this curve."
                )
                continue

            target_folder = scalar_map[scalar_key]
            logger.info(
                f"Processing '{exp_name}': Found '{scalar_key}' in '{target_folder}' folder."
            )

            all_runs_data = []
            for path in exp_details.get("paths") or []:
                log_dir = os.path.join(path, "monitor", "tensorboard", target_folder)
                if os.path.isdir(log_dir):
                    run_data = parse_tensorboard_log(log_dir, scalar_key)
                    if not run_data.empty:
                        all_runs_data.append(run_data)
                else:
                    logger.warning(f"Log directory not found for path: {log_dir}")

            experiments_data_for_this_plot[exp_name] = {
                "data": all_runs_data,
                "color": exp_details.get("color"),
            }

        # Generate dynamic plot details; the scalar key is sanitized so it can
        # serve as a file name (e.g. '/' becomes '_').
        clean_scalar_name = re.sub(r"[^a-zA-Z0-9_-]", "_", scalar_key)
        output_filename = os.path.join(output_path, f"{clean_scalar_name}.png")

        # Use templates for titles and labels if available
        title = plot_cfg.get("title", "{scalar_key}").format(scalar_key=scalar_key)
        x_label = plot_cfg.get("x_label", "Step")
        y_label = plot_cfg.get("y_label_template", "{scalar_key}").format(scalar_key=scalar_key)

        plot_confidence_interval(
            experiments_data=experiments_data_for_this_plot,
            title=title,
            x_label=x_label,
            y_label=y_label,
            output_filename=output_filename,
        )
    logger.info("\nAll plots generated successfully.")
245+
246+
247+
# Run the plotting pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# An example: comparison between training on gsm8k and math
2+
3+
# General configurations for plotting
4+
plot_configs:
5+
title: "Multi-exps Comparison for {scalar_key}"
6+
x_label: "Steps"
7+
y_label_template: "{scalar_key}"
8+
output_path: "scripts/multi_exps_plot/output"
9+
10+
# A list of all scalar keys to plot
11+
scalar_keys:
12+
- "eval/gsm8k-eval/accuracy/mean"
13+
- "response_length/mean"
14+
# - "critic/rewards/mean"
15+
16+
# Configurations for each experiment to be plotted
17+
exps_configs:
18+
# Define each experiment's name
19+
gsm8k-train:
20+
# 'paths' should point to the root directory of each run
21+
paths:
22+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-1"
23+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-2"
24+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-3"
25+
# - "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-n"
26+
27+
# If not provided, a default color will be used
28+
color: "blue"
29+
30+
math-train:
31+
paths:
32+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-1"
33+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-2"
34+
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-3"
35+
# - "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-n"
36+
color: "red"

0 commit comments

Comments
 (0)