-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path generate_sets.py
More file actions
42 lines (33 loc) · 1.46 KB
/
generate_sets.py
File metadata and controls
42 lines (33 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from argparse import ArgumentParser
import json
import os
from pathlib import Path
import random
from typing import List
# Directory three levels above this file (`.parent` applied three times).
# NOTE(review): the name suggests the project root, and it is not referenced
# anywhere else in this file — confirm whether other code imports it before
# removing or relocating this script.
ROOT_DIR = Path(__file__).parent.parent.parent
def save_paths_to_json(paths: List[Path], basepath: Path, file_name: str):
    """Write a JSON split manifest to ``<basepath>/<file_name>.json``.

    The manifest has the shape ``{"base_path": ..., "files": [...]}``, where
    ``base_path`` is *basepath* in POSIX form and each entry of ``files`` is
    the matching element of *paths* made relative to *basepath* (also POSIX
    form), kept in the order given.

    Args:
        paths: Files to record; each must be located under *basepath*.
        basepath: Directory the entries are made relative to; the manifest
            file is also written into this directory.
        file_name: Manifest file name without the ``.json`` extension.

    Raises:
        ValueError: If an element of *paths* is not under *basepath*
            (raised by ``Path.relative_to`` before anything is written).
    """
    # Separate name so the `paths` parameter keeps its documented Path type.
    relative_files = [path.relative_to(basepath).as_posix() for path in paths]
    dataset = {
        'base_path': basepath.as_posix(),
        'files': relative_files,
    }
    # Join with pathlib rather than string formatting so the separator is
    # always right; explicit encoding keeps the output locale-independent.
    output_path = basepath / f"{file_name}.json"
    with output_path.open('w', encoding='utf-8') as f:
        json.dump(dataset, f)
def main(args):
    """Shuffle the ``.npz`` files under ``args.dataset_path`` and write split manifests.

    Files are found recursively (``rglob``), sorted, shuffled with the
    module-level ``random`` generator, and partitioned in order into test,
    validation and training subsets. The subsets are written with
    ``save_paths_to_json`` as ``test_dataset.json``, ``val_dataset.json`` and
    ``train_dataset.json`` inside ``args.dataset_path``.

    Args:
        args: Namespace with ``dataset_path`` (Path), ``test_percentage`` and
            ``val_percentage``. The percentages are fractions (e.g. 0.05 for
            5%), must be non-negative, and may sum to at most 1; the
            training split receives all remaining files.

    Raises:
        ValueError: If a percentage is not a finite number, is negative, or
            the two percentages sum to more than 1.
    """
    from fractions import Fraction

    # Convert through str() so a value such as 0.15 is taken as exactly
    # 15/100; the floors below are then free of binary floating-point error.
    test_fraction = Fraction(str(args.test_percentage))
    val_fraction = Fraction(str(args.val_percentage))
    if test_fraction < 0 or val_fraction < 0:
        raise ValueError("test and validation percentages must be non-negative")
    if test_fraction + val_fraction > 1:
        raise ValueError("test and validation percentages must sum to at most 1")

    # rglob() yields files in filesystem order, which differs between
    # machines; sorting first makes the split depend only on the state of the
    # `random` module, so seeding it gives a reproducible split.
    dataset_files = sorted(args.dataset_path.rglob("*.npz"))
    random.shuffle(dataset_files)

    ds_len = len(dataset_files)
    # Cumulative borders: test = [0, test_border), val = [test_border,
    # val_border), train = the rest. Multiplying by an exact fraction avoids
    # dividing by a float reciprocal (n // (1 / p)), which fails for p == 0
    # and can undercount, e.g. 20 // (1 / 0.15) == 2 rather than 3.
    test_border = int(ds_len * test_fraction)
    val_border = int(ds_len * (test_fraction + val_fraction))

    test_ds_files = dataset_files[:test_border]
    val_ds_files = dataset_files[test_border:val_border]
    train_ds_files = dataset_files[val_border:]
    save_paths_to_json(test_ds_files, args.dataset_path, "test_dataset")
    save_paths_to_json(val_ds_files, args.dataset_path, "val_dataset")
    save_paths_to_json(train_ds_files, args.dataset_path, "train_dataset")
if __name__ == "__main__":
    # Named `parser` so it is not mistaken for the stdlib `argparse` module.
    parser = ArgumentParser(
        description="Shuffle the .npz files found recursively under dataset_path "
                    "and write test/val/train JSON split manifests into it."
    )
    parser.add_argument("dataset_path", type=Path, help="path to the dataset containing npz for training")
    parser.add_argument(
        "--test-percentage", type=float, default=0.05,
        help="fraction (0-1) of files assigned to the test split (default: %(default)s)",
    )
    parser.add_argument(
        "--val-percentage", type=float, default=0.05,
        help="fraction (0-1) of files assigned to the validation split (default: %(default)s)",
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="seed for the shuffle so the split can be reproduced (default: unseeded)",
    )
    cli_args = parser.parse_args()
    # main() shuffles with the module-level `random` generator, so seeding it
    # here is enough to make the split repeatable.
    if cli_args.seed is not None:
        random.seed(cli_args.seed)
    main(cli_args)