-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugment_by_mask.py
More file actions
87 lines (67 loc) · 3.02 KB
/
augment_by_mask.py
File metadata and controls
87 lines (67 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from argparse import ArgumentParser
import itertools
from pathlib import Path
from typing import List
from joblib import Parallel, cpu_count, delayed
from tqdm import tqdm
import numpy as np
from dataset.dataset_interface import DatasetInterface
from utils.general_utils import split
def augment(data: List[np.ndarray]) -> List[List[np.ndarray]]:
    """Generate masked variants of an aligned RGB-D frame pair.

    Every non-empty subset of the per-instance masks is OR-combined into a
    single binary mask, which is then applied to all four image channels:
    RGB pixels outside the mask become 0, depth values become NaN.

    Args:
        data: five arrays ``[rs_rgb, rs_depth, zv_rgb, zv_depth, masks]``;
            ``masks`` stacks the instance masks along its last axis, i.e.
            shape (H, W, num_masks). (Assumed from the indexing below —
            TODO confirm against DatasetInterface.load.)

    Returns:
        One entry per non-empty mask subset; each entry is
        ``[rs_rgb_masked, rs_depth_masked, zv_rgb_masked, zv_depth_masked, mask]``.

    Raises:
        AssertionError: if ``masks`` is None.

    Note: the original annotated the parameter as ``List[np.array]`` —
    ``np.array`` is a factory function, not a type; fixed to ``np.ndarray``.
    """
    rs_rgb, rs_depth, zv_rgb, zv_depth, masks = data
    assert masks is not None, "mask must be present to augment data"

    num_masks = masks.shape[-1]
    # every binary on/off combination of the individual instance masks
    masks_combinations = list(itertools.product([0, 1], repeat=num_masks))
    print(f"generate {len(masks_combinations)} augmentations for image")

    augmented_dataset = []
    for masks_combination in masks_combinations:
        masks_indices = np.nonzero(masks_combination)[0]
        if len(masks_indices) == 0:
            # the all-zero combination would blank the entire image
            continue

        # union of the selected masks, kept 3-dim so it broadcasts over RGB
        mask = np.expand_dims(np.sum(masks[:, :, masks_indices], axis=2) > 0, axis=2)

        rs_rgb_masked = rs_rgb * mask
        rs_depth_masked = np.where(mask, rs_depth, np.nan)
        zv_rgb_masked = zv_rgb * mask
        zv_depth_masked = np.where(mask, zv_depth, np.nan)

        augmented_dataset.append([rs_rgb_masked, rs_depth_masked, zv_rgb_masked, zv_depth_masked, mask])

    return augmented_dataset
def augment_files(in_dir: Path, out_dir: Path, files: List[Path]) -> int:
    """Augment each dataset file in *files* and save the results.

    Output files mirror the directory layout of *in_dir* under *out_dir*
    and carry the augmentation index as a stem suffix
    (``<stem>_<aug_idx><suffix>``).

    Args:
        in_dir: root directory the input files live under (used to compute
            the relative output path).
        out_dir: root directory the augmented files are written to.
        files: dataset files to process.

    Returns:
        Total number of augmented files generated.
    """
    num_augmented_files = 0
    num_files = len(files)
    for file_idx, file in enumerate(files):
        print(f"processing {file}, {file_idx}/{num_files}")
        relative_dir_path = file.relative_to(in_dir).parent
        data = DatasetInterface.load(file)

        augmented_dataset = augment(data)
        num_augmented_files += len(augmented_dataset)

        # fix: the original inner loop re-used `idx`, shadowing the outer
        # enumerate variable — harmless today but error-prone; kept separate
        for aug_idx, augmented_data in enumerate(augmented_dataset):
            DatasetInterface.save(
                *augmented_data,
                file_name=out_dir / relative_dir_path / f"{file.stem}_{aug_idx}{file.suffix}"
            )

    return num_augmented_files
def main(args):
    """Run mask-based augmentation over every dataset file under ``args.in_dir``.

    With ``--jobs > 1`` the file list is chunked and processed by a joblib
    worker pool; otherwise files are handled sequentially with a progress bar.
    """
    print(f"""
input directory: {args.in_dir}
output directory: {args.out_dir}
number of jobs: {args.jobs}
""")

    files = DatasetInterface.get_paths_in_dir(args.in_dir, recursive=True)

    if args.jobs > 1:
        # one chunk per worker; each worker returns its own count
        chunks = split(files, args.jobs)
        counts = Parallel(n_jobs=args.jobs)(
            delayed(augment_files)(args.in_dir, args.out_dir, chunk)
            for chunk in chunks
        )
    else:
        counts = 0
        for path in tqdm(files):
            counts += augment_files(args.in_dir, args.out_dir, [path])

    # np.sum handles both the parallel case (list of ints) and the scalar case
    print(f"Generated {np.sum(counts)} augmented files")
if __name__ == "__main__":
argparse = ArgumentParser()
argparse.add_argument("in_dir", type=Path, help='datset, the augmentation should be computed for')
argparse.add_argument("out_dir", type=Path, help='directory the dataset including augmentations should be saved to')
argparse.add_argument("--jobs", type=int, default=cpu_count(), help='number of processes to use')
main(argparse.parse_args())