Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/sphinx_doc/assets/scripts-multi-plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ dev = [
"pytest>=8.0.0",
"pytest-json-ctrf",
"parameterized",
"matplotlib"
]

doc = [
Expand Down
63 changes: 63 additions & 0 deletions scripts/multi_exps_plot/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Multi-Run Experiment Plotting Script

## Description

Due to the stochastic nature of RFT results, multiple experimental runs are necessary to establish reliable average performance and confidence intervals. This script is designed to parse and plot the results from these repeated runs, enabling visual comparisons between different sets of experiments.

## Usage

***Before running this script***, ensure your experiment results are available. For example, after running the [grpo_gsm8k](https://github.com/modelscope/Trinity-RFT/blob/main/examples/grpo_gsm8k/gsm8k.yaml) script **three times**, the result directories will be located under a path pattern such as `/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-{1, 2, 3}`. The directory structure for a single run is expected to be as follows:

└── qwen2.5-1.5B-gsm8k-1
├── buffer
├── global_step_xxx
└── monitor
└── tensorboard
├── explorer
├── trainer
└── ...


***To run the script***, you need to configure the following key parameters in `plot_configs.yaml`:

```yaml
plot_configs:
# A list of all scalar keys to plot
scalar_keys:
- "eval/gsm8k-eval/accuracy/mean"
- "response_length/mean"
# - "critic/rewards/mean"

exps_configs:
# Define each experiment group to be plotted
gsm8k-train:
# 'paths' should list the root directories of each individual run
paths:
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-1"
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-2"
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-3"
# - "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-n"

# Optional: Color of the curve.
color: "blue"

# Define other experiment groups for comparison
math-train:
paths:
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-1"
# ...
color: "red"
```


Once the `YAML` file is configured, execute the following command to generate the plot:

```bash
python scripts/multi_exps_plot/multi_exps_plot.py --config scripts/multi_exps_plot/plot_configs.yaml
```

## Example

Below is an example of the output by this script. The experiment shows `Qwen2.5-1.5B-Instruct` RFT with `GRPO` on the `GSM8k` and `MATH` datasets, with performance evaluated on the `MATH500` benchmark.

![Example Plot of GRPO on GSM8k and MATH](../../docs/sphinx_doc/assets/scripts-multi-plot.png)
248 changes: 248 additions & 0 deletions scripts/multi_exps_plot/multi_exps_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import argparse
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from tensorboard.backend.event_processing import event_accumulator

from trinity.utils.log import get_logger

# Initialize logger
logger = get_logger(__name__)


def parse_args():
parser = argparse.ArgumentParser(description="Plot multi results from TensorBoard logs.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the YAML configuration file."
)
return parser.parse_args()


def load_config(config_path: str) -> dict:
if not os.path.exists(config_path):
logger.error(f"Configuration file not found: {config_path}")
raise FileNotFoundError(f"Configuration file not found: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
try:
config = yaml.safe_load(f)
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML file: {e}", exc_info=True)
raise ValueError(f"Error parsing YAML file: {e}")
return config


def find_scalars_in_event_file(event_file: str) -> list[str]:
"""Scans a single tfevents file and returns all scalar keys"""
try:
ea = event_accumulator.EventAccumulator(
event_file, size_guidance={event_accumulator.SCALARS: 0}
)
ea.Reload()
return ea.scalars.Keys()
except Exception as e:
logger.warning(f"Could not read scalars from event file '{event_file}': {e}")
return []


def build_scalar_location_map(base_path: str) -> dict[str, str]:
"""Find all scalars in 'explorer' and 'trainer'"""
scalar_map = {}
for folder in ["explorer", "trainer"]:
log_dir = os.path.join(base_path, "monitor", "tensorboard", folder)
if not os.path.isdir(log_dir):
continue

event_files = glob.glob(os.path.join(log_dir, "events.out.tfevents.*"))
if not event_files:
continue

# Use the first event file found in the directory
keys = find_scalars_in_event_file(event_files[0])
for key in keys:
if key in scalar_map:
logger.warning(
f"Duplicate scalar key '{key}' found. Using first one found ('{scalar_map[key]}')."
)
else:
scalar_map[key] = folder
return scalar_map


def find_tfevents_file(dir_path: str) -> str | None:
"""Finds a tfevents file within a specified directory"""
event_files = glob.glob(os.path.join(dir_path, "events.out.tfevents.*"))
if not event_files:
return None
if len(event_files) > 1:
latest_file = sorted(event_files)[-1]
logger.debug(
f"Multiple tfevents files found in '{dir_path}'. Using the latest one: {latest_file}"
)
return latest_file
return event_files[0]


def parse_tensorboard_log(log_dir: str, scalar_key: str) -> pd.Series:
"""Parses a single TensorBoard log directory to extract scalar data"""
try:
event_file = find_tfevents_file(log_dir)
if event_file is None:
raise FileNotFoundError(f"No tfevents file found in directory: '{log_dir}'")

ea = event_accumulator.EventAccumulator(
event_file, size_guidance={event_accumulator.SCALARS: 0}
)
ea.Reload()

if scalar_key not in ea.scalars.Keys():
logger.warning(f"Scalar key '{scalar_key}' not found in file '{event_file}'.")
return pd.Series(dtype=np.float64)

scalar_events = ea.scalars.Items(scalar_key)
steps = [e.step for e in scalar_events]
values = [e.value for e in scalar_events]

return pd.Series(data=values, index=steps, name=log_dir)

except Exception as e:
logger.error(f"Failed to parse directory '{log_dir}': {e}")
return pd.Series(dtype=np.float64)


def plot_confidence_interval(
experiments_data: dict, title: str, x_label: str, y_label: str, output_filename: str
):
"""Plots the mean and confidence interval for multiple experiments"""
plt.style.use("seaborn-v0_8-whitegrid")
fig, ax = plt.subplots(figsize=(12, 7))
color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]

for i, (exp_name, exp_details) in enumerate(experiments_data.items()):
all_runs_data = exp_details["data"]
color = exp_details.get("color") or color_cycle[i % len(color_cycle)]

if not all_runs_data:
logger.warning(f"No valid data for experiment '{exp_name}' on this plot. Skipping.")
continue

df = pd.concat(all_runs_data, axis=1)
mean_values = df.mean(axis=1).sort_index()
std_values = df.std(axis=1).sort_index()
steps = mean_values.index.values

ax.plot(
steps, mean_values, label=exp_name, color=color, marker="o", markersize=4, linestyle="-"
)
ax.fill_between(
steps, mean_values - std_values, mean_values + std_values, color=color, alpha=0.2
)

ax.set_title(title, fontsize=16, pad=20)
ax.set_xlabel(x_label, fontsize=12)
ax.set_ylabel(y_label, fontsize=12)
ax.legend(loc="best", fontsize=12)
ax.tick_params(axis="both", which="major", labelsize=10)

output_dir = os.path.dirname(output_filename)
if output_dir:
os.makedirs(output_dir, exist_ok=True)

plt.tight_layout()
plt.savefig(output_filename, dpi=300)
logger.info(f"Chart successfully saved to '{output_filename}'")
plt.close(fig)


def main():
args = parse_args()
config = load_config(args.config)
logger.info(f"Successfully loaded configuration from: {args.config}")

# Extract settings
plot_cfg = config.get("plot_configs", {})
exps_cfg = config.get("exps_configs", {})

output_path = plot_cfg.get("output_path", "./plots")
scalar_keys_to_plot = plot_cfg.get("scalar_keys", [])

if not scalar_keys_to_plot:
logger.warning("No 'scalar_keys' specified in 'plot_configs'.")
return

# Build scalar location maps for each experiment group
scalar_maps = {}
for exp_name, exp_details in exps_cfg.items():
logger.info(f"Scanning for scalars in experiment group: {exp_name}")
for path in exp_details.get("paths", []):
if os.path.isdir(path):
scalar_maps[exp_name] = build_scalar_location_map(path)
if scalar_maps[exp_name]:
logger.info(
f"Scalar map for '{exp_name}' created successfully from path: {path}"
)
break
if exp_name not in scalar_maps:
logger.warning(
f"Could not create a scalar map for '{exp_name}'. All paths might be invalid."
)
scalar_maps[exp_name] = {}

# Main Loop: Generate one plot for each specified scalar key
for scalar_key in scalar_keys_to_plot:
logger.info(f"\n--- Generating plot for scalar_key: '{scalar_key}' ---")
experiments_data_for_this_plot = {}

for exp_name, exp_details in exps_cfg.items():
scalar_map = scalar_maps.get(exp_name, {})
if scalar_key not in scalar_map:
logger.warning(
f"Scalar '{scalar_key}' not found for experiment '{exp_name}'. Skipping this curve."
)
continue

target_folder = scalar_map[scalar_key]
logger.info(
f"Processing '{exp_name}': Found '{scalar_key}' in '{target_folder}' folder."
)

all_runs_data = []
for path in exp_details.get("paths", []):
log_dir = os.path.join(path, "monitor", "tensorboard", target_folder)
if os.path.isdir(log_dir):
run_data = parse_tensorboard_log(log_dir, scalar_key)
if not run_data.empty:
all_runs_data.append(run_data)
else:
logger.warning(f"Log directory not found for path: {log_dir}")

experiments_data_for_this_plot[exp_name] = {
"data": all_runs_data,
"color": exp_details.get("color"),
}

# Generate dynamic plot details
clean_scalar_name = re.sub(r"[^a-zA-Z0-9_-]", "_", scalar_key)
output_filename = os.path.join(output_path, f"{clean_scalar_name}.png")

# Use templates for titles and labels if available
title = plot_cfg.get("title", "{scalar_key}").format(scalar_key=scalar_key)
x_label = plot_cfg.get("x_label", "Step")
y_label = plot_cfg.get("y_label_template", "{scalar_key}").format(scalar_key=scalar_key)

plot_confidence_interval(
experiments_data=experiments_data_for_this_plot,
title=title,
x_label=x_label,
y_label=y_label,
output_filename=output_filename,
)
logger.info("\nAll plots generated successfully.")


if __name__ == "__main__":
main()
36 changes: 36 additions & 0 deletions scripts/multi_exps_plot/plot_configs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# An example: comparison between training on gsm8k and math

# General configurations for plotting
plot_configs:
title: "Multi-exps Comparison for {scalar_key}"
x_label: "Steps"
y_label_template: "{scalar_key}"
output_path: "scripts/multi_exps_plot/output"

# A list of all scalar keys to plot
scalar_keys:
- "eval/gsm8k-eval/accuracy/mean"
- "response_length/mean"
# - "critic/rewards/mean"

# Configurations for each experiment to be plotted
exps_configs:
# Define each experiments' name
gsm8k-train:
# 'paths' should point to the root directory of each run
paths:
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-1"
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-2"
- "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-3"
# - "/PATH/TO/CHECKPOINT/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k-n"

# If not provided, a default color will be used
color: "blue"

math-train:
paths:
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-1"
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-2"
- "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-3"
# - "/PATH/TO/CHECKPOINT/Trinity-RFT-math/qwen2.5-1.5B-math-n"
color: "red"