@@ -16,7 +16,9 @@
 
 
 import argparse
+import gc
 import json
+import sys
 from multiprocessing import Pool
 from typing import Dict, List, Tuple, Union
 
@@ -94,15 +96,22 @@ def run_benchmarks(args: argparse.Namespace) -> int:
     bench_cases = early_filtering(bench_cases, param_filters)
 
     # prefetch datasets
-    if args.prefetch_datasets:
+    if args.prefetch_datasets or args.describe_datasets:
         # trick: get unique dataset names only to avoid loading of same dataset
         # by different cases/processes
         dataset_cases = {get_data_name(case): case for case in bench_cases}
         logger.debug(f"Unique dataset names to load:\n{list(dataset_cases.keys())}")
         n_proc = min([16, cpu_count(), len(dataset_cases)])
         logger.info(f"Prefetching datasets with {n_proc} processes")
         with Pool(n_proc) as pool:
-            pool.map(load_data, dataset_cases.values())
+            datasets = pool.map(load_data, dataset_cases.values())
+        if args.describe_datasets:
+            for (data, data_description), data_name in zip(datasets, dataset_cases.keys()):
+                print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}")
+            sys.exit(0)
+        # free memory used by prefetched datasets
+        del datasets
+        gc.collect()
 
     # run bench_cases
     return_code, result = call_benchmarks(
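
A minimal standalone sketch of the prefetch pattern added above: deduplicate cases by dataset name, load each unique dataset once across a process pool, and optionally print dataset descriptions before exiting. The get_data_name and load_data functions below are hypothetical stand-ins for the repo's real helpers, and the tiny numpy arrays are fabricated so the sketch runs on its own (assuming numpy is installed):

import sys
from multiprocessing import Pool, cpu_count
from typing import Dict, List, Tuple

import numpy as np


def get_data_name(case: Dict) -> str:
    # Hypothetical stand-in: the real helper derives a unique name from
    # the full dataset configuration, not a single key.
    return case["dataset"]


def load_data(case: Dict) -> Tuple[Dict, Dict]:
    # Hypothetical stand-in: fabricates a tiny in-memory dataset so the
    # sketch runs standalone instead of fetching anything.
    data = {"x": np.zeros((10, 3)), "y": np.zeros(10)}
    data_description = {"source": case["dataset"], "n_features": 3}
    return data, data_description


def prefetch_datasets(bench_cases: List[Dict], describe: bool = False) -> None:
    # Deduplicate by dataset name so the same dataset is never loaded
    # twice by different cases/processes.
    dataset_cases = {get_data_name(case): case for case in bench_cases}
    n_proc = min([16, cpu_count(), len(dataset_cases)])
    with Pool(n_proc) as pool:
        datasets = pool.map(load_data, dataset_cases.values())
    if describe:
        # pool.map preserves input order and dicts preserve insertion
        # order, so results line up with dataset_cases.keys().
        for (data, data_description), data_name in zip(datasets, dataset_cases.keys()):
            print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}")
        sys.exit(0)


if __name__ == "__main__":
    # "a" appears twice but is loaded only once.
    cases = [{"dataset": "a"}, {"dataset": "b"}, {"dataset": "a"}]
    prefetch_datasets(cases, describe=True)

In the actual change, after the optional describe-and-exit branch the parent process drops its references (del datasets) and calls gc.collect(), so the prefetched copies do not stay resident while the benchmarks themselves run.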