Skip to content

Commit fd2662f

Browse files
committed
Add runner argument to describe datasets
1 parent e76e463 commit fd2662f

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

sklbench/runner/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ def add_runner_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentPa
130130
action="store_true",
131131
help="Load all requested datasets in parallel before running benchmarks.",
132132
)
133+
parser.add_argument(
134+
"--describe-datasets",
135+
default=False,
136+
action="store_true",
137+
help="Load all requested datasets in parallel and show their parameters.",
138+
)
133139
# workflow control
134140
parser.add_argument(
135141
"--exit-on-error",

sklbench/runner/implementation.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717

1818
import argparse
19+
import gc
1920
import json
21+
import sys
2022
from multiprocessing import Pool
2123
from typing import Dict, List, Tuple, Union
2224

@@ -94,15 +96,22 @@ def run_benchmarks(args: argparse.Namespace) -> int:
9496
bench_cases = early_filtering(bench_cases, param_filters)
9597

9698
# prefetch datasets
97-
if args.prefetch_datasets:
99+
if args.prefetch_datasets or args.describe_datasets:
98100
# trick: get unique dataset names only to avoid loading the same dataset
99101
# by different cases/processes
100102
dataset_cases = {get_data_name(case): case for case in bench_cases}
101103
logger.debug(f"Unique dataset names to load:\n{list(dataset_cases.keys())}")
102104
n_proc = min([16, cpu_count(), len(dataset_cases)])
103105
logger.info(f"Prefetching datasets with {n_proc} processes")
104106
with Pool(n_proc) as pool:
105-
pool.map(load_data, dataset_cases.values())
107+
datasets = pool.map(load_data, dataset_cases.values())
108+
if args.describe_datasets:
109+
for ((data, data_description), data_name) in zip(datasets, dataset_cases.keys()):
110+
print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}")
111+
sys.exit(0)
112+
# free memory used by prefetched datasets
113+
del datasets
114+
gc.collect()
106115

107116
# run bench_cases
108117
return_code, result = call_benchmarks(

0 commit comments

Comments
 (0)