|
| 1 | +import argparse |
| 2 | +import glob |
| 3 | +import os |
| 4 | +import re |
| 5 | + |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import yaml |
| 10 | +from tensorboard.backend.event_processing import event_accumulator |
| 11 | + |
| 12 | +from trinity.utils.log import get_logger |
| 13 | + |
# Module-level logger shared by every function in this script.
logger = get_logger(__name__)
| 16 | + |
| 17 | + |
def parse_args():
    """Build the CLI parser and return the parsed command-line arguments.

    Returns:
        argparse.Namespace with a single required ``config`` attribute
        (path to the YAML configuration file).
    """
    arg_parser = argparse.ArgumentParser(
        description="Plot multi results from TensorBoard logs."
    )
    arg_parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML configuration file.",
    )
    return arg_parser.parse_args()
| 24 | + |
| 25 | + |
def load_config(config_path: str) -> dict:
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file is not valid YAML; the original
            ``yaml.YAMLError`` is preserved as the cause.
    """
    if not os.path.exists(config_path):
        logger.error(f"Configuration file not found: {config_path}")
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            logger.error(f"Error parsing YAML file: {e}", exc_info=True)
            # Chain the cause so the YAML parser's position info survives
            # in the traceback (the original raise dropped it).
            raise ValueError(f"Error parsing YAML file: {e}") from e
    return config
| 37 | + |
| 38 | + |
def find_scalars_in_event_file(event_file: str) -> list[str]:
    """Scan a single tfevents file and return every scalar key it contains.

    Returns an empty list (after logging a warning) when the file
    cannot be read.
    """
    try:
        # size_guidance=0 keeps every scalar event instead of sampling.
        accumulator = event_accumulator.EventAccumulator(
            event_file,
            size_guidance={event_accumulator.SCALARS: 0},
        )
        accumulator.Reload()
        return accumulator.scalars.Keys()
    except Exception as e:
        logger.warning(f"Could not read scalars from event file '{event_file}': {e}")
        return []
| 50 | + |
| 51 | + |
def build_scalar_location_map(base_path: str) -> dict[str, str]:
    """Map every scalar key to the folder ('explorer' or 'trainer') it lives in.

    Only the first tfevents file of each folder is scanned; on duplicate
    keys the first folder found wins and a warning is logged.
    """
    scalar_map: dict[str, str] = {}
    for folder in ("explorer", "trainer"):
        log_dir = os.path.join(base_path, "monitor", "tensorboard", folder)
        if not os.path.isdir(log_dir):
            continue

        event_files = glob.glob(os.path.join(log_dir, "events.out.tfevents.*"))
        if not event_files:
            continue

        # Only the first event file found in the directory is scanned.
        for key in find_scalars_in_event_file(event_files[0]):
            if key in scalar_map:
                logger.warning(
                    f"Duplicate scalar key '{key}' found. Using first one found ('{scalar_map[key]}')."
                )
            else:
                scalar_map[key] = folder
    return scalar_map
| 74 | + |
| 75 | + |
| 76 | +def find_tfevents_file(dir_path: str) -> str | None: |
| 77 | + """Finds a tfevents file within a specified directory""" |
| 78 | + event_files = glob.glob(os.path.join(dir_path, "events.out.tfevents.*")) |
| 79 | + if not event_files: |
| 80 | + return None |
| 81 | + if len(event_files) > 1: |
| 82 | + latest_file = sorted(event_files)[-1] |
| 83 | + logger.debug( |
| 84 | + f"Multiple tfevents files found in '{dir_path}'. Using the latest one: {latest_file}" |
| 85 | + ) |
| 86 | + return latest_file |
| 87 | + return event_files[0] |
| 88 | + |
| 89 | + |
def parse_tensorboard_log(log_dir: str, scalar_key: str) -> pd.Series:
    """Extract one scalar's history from a TensorBoard log directory.

    Returns a Series indexed by step, named after *log_dir*; an empty
    float64 Series is returned (after logging) on any failure or when
    the key is absent.
    """
    empty_series = pd.Series(dtype=np.float64)
    try:
        event_file = find_tfevents_file(log_dir)
        if event_file is None:
            raise FileNotFoundError(f"No tfevents file found in directory: '{log_dir}'")

        ea = event_accumulator.EventAccumulator(
            event_file, size_guidance={event_accumulator.SCALARS: 0}
        )
        ea.Reload()

        if scalar_key not in ea.scalars.Keys():
            logger.warning(f"Scalar key '{scalar_key}' not found in file '{event_file}'.")
            return empty_series

        scalar_events = ea.scalars.Items(scalar_key)
        # Index by global step; the run directory becomes the series name
        # so concatenated columns stay distinguishable.
        return pd.Series(
            data=[e.value for e in scalar_events],
            index=[e.step for e in scalar_events],
            name=log_dir,
        )
    except Exception as e:
        logger.error(f"Failed to parse directory '{log_dir}': {e}")
        return empty_series
| 115 | + |
| 116 | + |
def plot_confidence_interval(
    experiments_data: dict, title: str, x_label: str, y_label: str, output_filename: str
):
    """Plot mean curves with a ±1 std band for several experiments.

    ``experiments_data`` maps experiment name -> {"data": [pd.Series, ...],
    "color": optional matplotlib color}. The figure is written to
    ``output_filename`` (directories are created as needed).
    """
    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(figsize=(12, 7))
    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    for idx, (exp_name, exp_details) in enumerate(experiments_data.items()):
        runs = exp_details["data"]
        if not runs:
            logger.warning(f"No valid data for experiment '{exp_name}' on this plot. Skipping.")
            continue

        # Explicit color wins; otherwise cycle through the default palette.
        curve_color = exp_details.get("color") or palette[idx % len(palette)]

        combined = pd.concat(runs, axis=1)
        mean_values = combined.mean(axis=1).sort_index()
        std_values = combined.std(axis=1).sort_index()
        steps = mean_values.index.values

        ax.plot(
            steps,
            mean_values,
            label=exp_name,
            color=curve_color,
            marker="o",
            markersize=4,
            linestyle="-",
        )
        ax.fill_between(
            steps,
            mean_values - std_values,
            mean_values + std_values,
            color=curve_color,
            alpha=0.2,
        )

    ax.set_title(title, fontsize=16, pad=20)
    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.legend(loc="best", fontsize=12)
    ax.tick_params(axis="both", which="major", labelsize=10)

    output_dir = os.path.dirname(output_filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    plt.tight_layout()
    plt.savefig(output_filename, dpi=300)
    logger.info(f"Chart successfully saved to '{output_filename}'")
    plt.close(fig)
| 159 | + |
| 160 | + |
def _build_scalar_maps(exps_cfg: dict) -> dict[str, dict[str, str]]:
    """Build a {experiment name -> {scalar key -> folder}} map per experiment group.

    The first path whose scan yields a non-empty map is used for each group;
    a warning is logged when no usable map could be built.
    """
    scalar_maps: dict[str, dict[str, str]] = {}
    for exp_name, exp_details in exps_cfg.items():
        logger.info(f"Scanning for scalars in experiment group: {exp_name}")
        for path in exp_details.get("paths", []):
            if not os.path.isdir(path):
                continue
            scalar_maps[exp_name] = build_scalar_location_map(path)
            if scalar_maps[exp_name]:
                logger.info(
                    f"Scalar map for '{exp_name}' created successfully from path: {path}"
                )
                break
        # Fix: also warn when every valid path produced an EMPTY map — the
        # original only warned when the key was missing entirely, so a
        # directory with no scalars silently yielded no warning.
        if not scalar_maps.get(exp_name):
            logger.warning(
                f"Could not create a scalar map for '{exp_name}'. All paths might be invalid."
            )
            scalar_maps[exp_name] = {}
    return scalar_maps


def _collect_runs(paths: list, target_folder: str, scalar_key: str) -> list:
    """Collect one scalar Series per run directory; missing dirs are warned and skipped."""
    runs = []
    for path in paths:
        log_dir = os.path.join(path, "monitor", "tensorboard", target_folder)
        if not os.path.isdir(log_dir):
            logger.warning(f"Log directory not found for path: {log_dir}")
            continue
        run_data = parse_tensorboard_log(log_dir, scalar_key)
        if not run_data.empty:
            runs.append(run_data)
    return runs


def main():
    """Entry point: load the YAML config, locate scalars, and plot each requested key."""
    args = parse_args()
    config = load_config(args.config)
    logger.info(f"Successfully loaded configuration from: {args.config}")

    # Extract settings
    plot_cfg = config.get("plot_configs", {})
    exps_cfg = config.get("exps_configs", {})

    output_path = plot_cfg.get("output_path", "./plots")
    scalar_keys_to_plot = plot_cfg.get("scalar_keys", [])
    if not scalar_keys_to_plot:
        logger.warning("No 'scalar_keys' specified in 'plot_configs'.")
        return

    # Build scalar location maps for each experiment group.
    scalar_maps = _build_scalar_maps(exps_cfg)

    # Main loop: generate one plot for each specified scalar key.
    for scalar_key in scalar_keys_to_plot:
        logger.info(f"\n--- Generating plot for scalar_key: '{scalar_key}' ---")
        experiments_data_for_this_plot = {}

        for exp_name, exp_details in exps_cfg.items():
            scalar_map = scalar_maps.get(exp_name, {})
            if scalar_key not in scalar_map:
                logger.warning(
                    f"Scalar '{scalar_key}' not found for experiment '{exp_name}'. Skipping this curve."
                )
                continue

            target_folder = scalar_map[scalar_key]
            logger.info(
                f"Processing '{exp_name}': Found '{scalar_key}' in '{target_folder}' folder."
            )
            experiments_data_for_this_plot[exp_name] = {
                "data": _collect_runs(exp_details.get("paths", []), target_folder, scalar_key),
                "color": exp_details.get("color"),
            }

        # Sanitize the scalar key so it is a safe file name.
        clean_scalar_name = re.sub(r"[^a-zA-Z0-9_-]", "_", scalar_key)
        output_filename = os.path.join(output_path, f"{clean_scalar_name}.png")

        # Use templates for titles and labels if available.
        title = plot_cfg.get("title", "{scalar_key}").format(scalar_key=scalar_key)
        x_label = plot_cfg.get("x_label", "Step")
        y_label = plot_cfg.get("y_label_template", "{scalar_key}").format(scalar_key=scalar_key)

        plot_confidence_interval(
            experiments_data=experiments_data_for_this_plot,
            title=title,
            x_label=x_label,
            y_label=y_label,
            output_filename=output_filename,
        )
    logger.info("\nAll plots generated successfully.")
| 245 | + |
| 246 | + |
| 247 | +if __name__ == "__main__": |
| 248 | + main() |
0 commit comments