3939
4040@dataclass
4141class BenchmarkConfig :
42- num_train_samples : int = 10000
42+ num_train_samples : int = 5000
4343 """Number of training samples to generate from the Lorenz63 dataset."""
4444 num_test_samples : int = 1000
4545 """Number of test samples to generate from the Lorenz63 dataset."""
46- num_repeats : int = 10
46+ num_repeats : int = 3
4747 """Number of times to repeat each benchmark for statistics."""
4848 rank : int = 25
4949 """Rank of the Koopman operator approximation."""
@@ -63,8 +63,10 @@ class BenchmarkConfig:
6363 """Base random seed for reproducibility."""
6464 save_json : bool = True
6565 """Whether to save results to JSON file."""
66- make_plots : bool = False
66+ make_plots : bool = True
6767 """Whether to generate plots."""
68+ results_json_path : str | None = None
69+ """Path to load existing results JSON for plotting only."""
6870
6971
7072def timer (func ):
@@ -291,9 +293,9 @@ def plot_benchmark(results: dict, metric: str, filename: str, color: str):
291293
292294 # Formatting
293295 if metric == "fit_time" :
294- label_text = f"{ val :.2f} ± { iqr :.2f } s" if is_valid else "FAILED"
296+ label_text = f"{ val :.2f} s" if is_valid else "FAILED"
295297 else :
296- label_text = f"{ val :.4f} ± { iqr :.4f } " if is_valid else "N/A"
298+ label_text = f"{ val :.4f} " if is_valid else "N/A"
297299
298300 processed .append (
299301 {
@@ -318,7 +320,6 @@ def plot_benchmark(results: dict, metric: str, filename: str, color: str):
318320 bars = ax .barh (
319321 names ,
320322 values ,
321- xerr = errors ,
322323 color = color ,
323324 height = 0.6 ,
324325 capsize = 5 ,
@@ -423,4 +424,9 @@ def run_benchmarks(config: BenchmarkConfig) -> None:
423424
424425if __name__ == "__main__" :
425426 configs = tyro .cli (BenchmarkConfig )
426- run_benchmarks (configs )
427+ if configs .results_json_path :
428+ with open (configs .results_json_path , "r" ) as f :
429+ results = json .load (f )
430+ plot_benchmark (results , "fit_time" , "fit_time_benchmarks.svg" , "#2A7E68" )
431+ else :
432+ run_benchmarks (configs )
0 commit comments