Skip to content

Commit fac6260

Browse files
committed
Last tune on the default settings
1 parent 5b81499 commit fac6260

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

benchmarks/fit_time.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@
3939

4040
@dataclass
4141
class BenchmarkConfig:
42-
num_train_samples: int = 10000
42+
num_train_samples: int = 5000
4343
"""Number of training samples to generate from the Lorenz63 dataset."""
4444
num_test_samples: int = 1000
4545
"""Number of test samples to generate from the Lorenz63 dataset."""
46-
num_repeats: int = 10
46+
num_repeats: int = 3
4747
"""Number of times to repeat each benchmark for statistics."""
4848
rank: int = 25
4949
"""Rank of the Koopman operator approximation."""
@@ -63,8 +63,10 @@ class BenchmarkConfig:
6363
"""Base random seed for reproducibility."""
6464
save_json: bool = True
6565
"""Whether to save results to JSON file."""
66-
make_plots: bool = False
66+
make_plots: bool = True
6767
"""Whether to generate plots."""
68+
results_json_path: str | None = None
69+
"""Path to load existing results JSON for plotting only."""
6870

6971

7072
def timer(func):
@@ -291,9 +293,9 @@ def plot_benchmark(results: dict, metric: str, filename: str, color: str):
291293

292294
# Formatting
293295
if metric == "fit_time":
294-
label_text = f"{val:.2f}±{iqr:.2f}s" if is_valid else "FAILED"
296+
label_text = f"{val:.2f}s" if is_valid else "FAILED"
295297
else:
296-
label_text = f"{val:.4f}±{iqr:.4f}" if is_valid else "N/A"
298+
label_text = f"{val:.4f}" if is_valid else "N/A"
297299

298300
processed.append(
299301
{
@@ -318,7 +320,6 @@ def plot_benchmark(results: dict, metric: str, filename: str, color: str):
318320
bars = ax.barh(
319321
names,
320322
values,
321-
xerr=errors,
322323
color=color,
323324
height=0.6,
324325
capsize=5,
@@ -423,4 +424,9 @@ def run_benchmarks(config: BenchmarkConfig) -> None:
423424

424425
if __name__ == "__main__":
425426
configs = tyro.cli(BenchmarkConfig)
426-
run_benchmarks(configs)
427+
if configs.results_json_path:
428+
with open(configs.results_json_path, "r") as f:
429+
results = json.load(f)
430+
plot_benchmark(results, "fit_time", "fit_time_benchmarks.svg", "#2A7E68")
431+
else:
432+
run_benchmarks(configs)

0 commit comments

Comments
 (0)