import os
import os.path
import numpy as np
import torch.utils.data as data
import h5py
import transforms

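# NOTE: `transforms` refers to the repo-local transforms module that operates
# on numpy arrays (Crop, Rotate, Resize, ...), not torchvision.transforms.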
IMG_EXTENSIONS = [
    '.h5',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


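# Map each top-level subdirectory of `dir` to a class index; sorting keeps
# the mapping deterministic across runs.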
def find_classes(dir):
    classes = [d for d in os.listdir(dir)
               if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


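# Walk <dir>/<class>/... and collect a (path, class_index) pair for every
# file with a supported extension.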
def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


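# Each sample is an HDF5 file holding an 'rgb' dataset (3 x H x W) and a
# 'depth' dataset (H x W); the rgb array is transposed to H x W x 3.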
def h5_loader(path):
    # use a context manager so the file handle is closed even on error
    with h5py.File(path, "r") as h5f:
        rgb = np.array(h5f['rgb'])
        rgb = np.transpose(rgb, (1, 2, 0))
        depth = np.array(h5f['depth'])
    return rgb, depth


oheight, owidth = 228, 912  # image size after pre-processing
color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)


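# Training augmentation: a fixed KITTI crop, then random rotation, random
# scaling, a center crop to (oheight, owidth), and a random horizontal flip,
# applied identically to rgb and depth; color jitter is applied to rgb only.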
def train_transform(rgb, depth):
    s = np.random.uniform(1.0, 1.5)  # random scaling factor
    # enlarging the image by s is equivalent to moving the camera closer,
    # so scene depth shrinks by the same factor
    depth_np = depth / s
    angle = np.random.uniform(-5.0, 5.0)  # random rotation in degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    # geometric part of the augmentation, shared by rgb and depth
    transform = transforms.Compose([
        transforms.Crop(130, 10, 240, 1200),
        transforms.Rotate(angle),
        transforms.Resize(s),
        transforms.CenterCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])
    rgb_np = transform(rgb)

    # random photometric jittering, rgb only
    rgb_np = color_jitter(rgb_np)

    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    # scipy's affine_transform raised a RuntimeError when the raw depth map
    # was passed in, so convert it to a float32 array first
    depth_np = np.asfarray(depth_np, dtype='float32')
    depth_np = transform(depth_np)

    return rgb_np, depth_np


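# Validation uses only the deterministic part of the pipeline: the fixed
# KITTI crop followed by a center crop, with no random augmentation.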
def val_transform(rgb, depth):
    depth_np = depth

    # deterministic geometric transform, shared by rgb and depth
    transform = transforms.Compose([
        transforms.Crop(130, 10, 240, 1200),
        transforms.CenterCrop((oheight, owidth)),
    ])
    rgb_np = transform(rgb)
    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    depth_np = np.asfarray(depth_np, dtype='float32')
    depth_np = transform(depth_np)

    return rgb_np, depth_np


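# Luma conversion with the standard ITU-R BT.601 coefficients.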
def rgb2grayscale(rgb):
    return rgb[:, :, 0] * 0.2989 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114


to_tensor = transforms.ToTensor()


class KITTIDataset(data.Dataset):
    modality_names = ['rgb', 'rgbd', 'd']  # potential extensions: 'g', 'gd'

    def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        if type == 'train':
            self.transform = train_transform
        elif type == 'val':
            self.transform = val_transform
        else:
            raise RuntimeError("Invalid dataset type: " + type + "\n"
                               "Supported dataset types are: train, val")
        self.loader = loader
        self.sparsifier = sparsifier

        if modality in self.modality_names:
            self.modality = modality
        else:
            raise RuntimeError("Invalid modality type: " + modality + "\n"
                               "Supported modalities are: " + ', '.join(self.modality_names))

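    # `sparsifier` is expected to expose dense_to_sparse(rgb, depth) and
    # return a boolean mask selecting which depth pixels to keep.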
    def create_sparse_depth(self, rgb, depth):
        if self.sparsifier is None:
            return depth
        else:
            mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
            sparse_depth = np.zeros(depth.shape)
            sparse_depth[mask_keep] = depth[mask_keep]
            return sparse_depth

    def create_rgbd(self, rgb, depth):
        # append the sparse depth as a fourth channel after the rgb channels
        sparse_depth = self.create_sparse_depth(rgb, depth)
        rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
        return rgbd

    def __getraw__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (rgb, depth) the raw data.
        """
        path, target = self.imgs[index]
        rgb, depth = self.loader(path)
        return rgb, depth

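    # Depending on `modality`, the network input is the rgb image (3 channels),
    # rgb plus sparse depth (4 channels), or sparse depth alone (1 channel).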
    def __get_all_item__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (input_tensor, depth_tensor, input_np, depth_np)
        """
        rgb, depth = self.__getraw__(index)
        if self.transform is not None:
            rgb_np, depth_np = self.transform(rgb, depth)
        else:
            raise RuntimeError("transform not defined")

        if self.modality == 'rgb':
            input_np = rgb_np
        elif self.modality == 'rgbd':
            input_np = self.create_rgbd(rgb_np, depth_np)
        elif self.modality == 'd':
            input_np = self.create_sparse_depth(rgb_np, depth_np)

        input_tensor = to_tensor(input_np)
        # depth-only inputs are 2-D; add a channel dimension so every
        # modality yields a C x H x W tensor
        while input_tensor.dim() < 3:
            input_tensor = input_tensor.unsqueeze(0)
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)

        return input_tensor, depth_tensor, input_np, depth_np

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (input_tensor, depth_tensor)
        """
        input_tensor, depth_tensor, _, _ = self.__get_all_item__(index)
        return input_tensor, depth_tensor

    def __len__(self):
        return len(self.imgs)
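

# Minimal usage sketch; the dataset root is a placeholder and must contain
# one subfolder per scene/class with .h5 samples inside, as expected by
# find_classes/make_dataset above.
if __name__ == '__main__':
    train_set = KITTIDataset('/path/to/kitti/train', type='train', modality='rgb')
    train_loader = data.DataLoader(train_set, batch_size=8, shuffle=True,
                                   num_workers=4)
    for input_tensor, depth_tensor in train_loader:
        # expected: [8, 3, 228, 912] inputs and [8, 1, 228, 912] depth targets
        print(input_tensor.shape, depth_tensor.shape)
        break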