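"""Prepare the DIODE depth dataset.

Downloads the official train/val archives into $TMPDIR, converts every
*_depth.npy map to a 16-bit PNG (stored as depth * 256), copies the RGB images
next to them under $TMPDIR/diode, and writes "<rgb> <depth>" split files for
the full, indoor, and outdoor subsets.
"""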
from collections import defaultdict
import multiprocessing
import os
import shutil

import numpy as np
from PIL import Image


def worker(args):
    """Convert one chunk of DIODE scan directories; return the number of samples written."""
    keys_list, fns, save_dir, src_dir = args
    # Each worker writes a partial split file named after the first and last scan
    # directory of its chunk; main_worker() merges these files afterwards.
    fp = open(os.path.join(src_dir, f"diode_{keys_list[0].split('/')[-1]}-{keys_list[-1].split('/')[-1]}.txt"), "w")
    cnt = 0
    for dirpath in keys_list:
        for f in fns[dirpath]:
            dest_dir = os.path.join(save_dir, dirpath)
            f = f.split(".")[0]

            # DIODE stores depth as an (H, W, 1) float array plus a validity mask;
            # zero out invalid pixels and clip the range to [0, 100] meters.
            depth = np.load(os.path.join(src_dir, dirpath, f + "_depth.npy")).astype(np.float32)
            depth_mask = np.load(os.path.join(src_dir, dirpath, f + "_depth_mask.npy"))
            depth[depth_mask < 1e-6] = 0
            depth = np.clip(depth, 0.0, 100)

            os.makedirs(dest_dir, exist_ok=True)
            depth_dst = os.path.join(dest_dir, f + "_depth.png")
            img_dst = os.path.join(dest_dir, f + ".png")
            img_src = os.path.join(src_dir, dirpath, f + ".png")

            # Save depth as a 16-bit PNG with a fixed scale of 256 (depth_m = png / 256).
            Image.fromarray((256.0 * depth[..., 0]).astype(np.uint16)).save(depth_dst)
            shutil.copy2(img_src, img_dst)
            fp.write(f"{img_dst.split(save_dir)[-1].strip('/')} {depth_dst.split(save_dir)[-1].strip('/')}\n")
            cnt += 1
    fp.close()
    return cnt


def main_worker(kind):
    """Convert one extracted DIODE split ("train" or "val") found under $TMPDIR."""
    local_path_to_splits = os.environ["TMPDIR"]
    output_save_path = os.path.join(os.environ["TMPDIR"], "diode")
    os.makedirs(output_save_path, exist_ok=True)
    folder = os.path.join(local_path_to_splits, kind)
    n_cpus = multiprocessing.cpu_count()

    # Collect the extracted files, grouped by scan directory (relative to $TMPDIR).
    fns = defaultdict(list)
    for dirpath, dirnames, filenames in os.walk(folder):
        dirpath = dirpath.split(local_path_to_splits)[-1].strip("/")
        if filenames:
            fns[dirpath].extend(filenames)

    # Reduce every *.png / *_depth.npy / *_depth_mask.npy triple to a single base
    # name and drop stray text files. Building a new list avoids the pitfalls of
    # removing items from a list while iterating over it.
    for dirpath, filenames in fns.items():
        stems = [
            f.split(".")[0].replace("_depth", "").replace("_mask", "")
            for f in filenames
            if not f.endswith(".txt")
        ]
        fns[dirpath] = np.unique(stems)

    # Split the scan directories into one chunk per CPU.
    chunk_s = len(fns) // n_cpus + 1
    keys_list = list(fns.keys())
    keys_list = [keys_list[i:i + chunk_s] for i in range(0, len(keys_list), chunk_s)]

    # fns is pickled once per task, so every worker receives its own copy.
    with multiprocessing.Pool(n_cpus) as p:
        res = p.imap_unordered(
            worker,
            zip(keys_list, [fns] * len(keys_list), [output_save_path] * len(keys_list), [local_path_to_splits] * len(keys_list)),
        )
        # Consume the lazy iterator before the pool is torn down.
        print("TOT: ", sum(res))

    # Merge the per-worker txt files into the final split files.
    fp_all = open(os.path.join(output_save_path, f"diode_{kind}.txt"), "w")
    fp_indoor = open(os.path.join(output_save_path, f"diode_indoor_{kind}.txt"), "w")
    fp_outdoor = open(os.path.join(output_save_path, f"diode_outdoor_{kind}.txt"), "w")
    for text in os.listdir(local_path_to_splits):
        if not text.endswith(".txt"):
            continue
        with open(os.path.join(local_path_to_splits, text)) as f:
            for line in f:
                fp_all.write(line)
                if "indoor" in line:
                    fp_indoor.write(line)
                if "outdoor" in line:
                    fp_outdoor.write(line)
        os.remove(os.path.join(local_path_to_splits, text))

    fp_all.close()
    fp_indoor.close()
    fp_outdoor.close()

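# Entry point: fetch each DIODE split archive into $TMPDIR (unless it is already
# there), extract it, and convert it with main_worker().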
if __name__ == '__main__':
    # Fall back to $HOME when no scratch directory is set, and export it so that
    # main_worker() reads the same location.
    temp_dir = os.environ.get("TMPDIR", os.environ["HOME"])
    os.environ["TMPDIR"] = temp_dir
    for kind in ["val", "train"]:
        if not os.path.exists(os.path.join(temp_dir, kind + '.tar.gz')):
            os.system(f"wget 'http://diode-dataset.s3.amazonaws.com/{kind}.tar.gz' -P {temp_dir}")
        os.system(f"tar -xzf {os.path.join(temp_dir, kind + '.tar.gz')} -C {temp_dir}")
        main_worker(kind)
        # Uncomment to free disk space once a split has been processed:
        # os.remove(os.path.join(temp_dir, kind + '.tar.gz'))
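# Note: each line of the generated split files is "<rgb png> <depth png>" with
# paths relative to $TMPDIR/diode. A downstream loader can recover metric depth
# as (a sketch of the inverse of the 256x scaling used in worker()):
#   depth_m = np.asarray(Image.open(depth_png), dtype=np.float32) / 256.0
# where a value of 0 marks pixels that were invalid in the original mask.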