Skip to content

Commit 82e91e7

Browse files
committed
add KITTI dataset loader
1 parent fff3469 commit 82e91e7

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed

kitti_dataloader.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import os
2+
import os.path
3+
import numpy as np
4+
import torch.utils.data as data
5+
import h5py
6+
import transforms
7+
8+
# File extensions recognized as dataset samples (KITTI frames are stored as HDF5).
IMG_EXTENSIONS = [
    '.h5',
]


def is_image_file(filename):
    """Return True if *filename* has one of the supported dataset extensions."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
15+
16+
17+
def find_classes(dir):
    """List the immediate subdirectories of *dir* as class names.

    Returns:
        tuple: (classes, class_to_idx) where *classes* is the sorted list of
        subdirectory names and *class_to_idx* maps each name to its index.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
23+
24+
25+
def make_dataset(dir, class_to_idx):
    """Collect (path, class_index) samples for every image file under *dir*.

    Walks each class subdirectory of *dir* recursively, in sorted order, and
    keeps files accepted by is_image_file.

    Args:
        dir (str): dataset root; each immediate subdirectory is one class.
        class_to_idx (dict): class name -> integer label (from find_classes).

    Returns:
        list: (path, class_index) tuples in deterministic sorted order.
    """
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        # Skip stray files at the root level; only class directories count.
        if not os.path.isdir(d):
            continue
        # NOTE: the original code also ran sorted(os.walk(d)) once purely to
        # print its length, traversing every subtree twice; that debug pass
        # (and the other debug prints) has been removed.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append((path, class_to_idx[target]))

    return images
45+
46+
47+
def h5_loader(path):
    """Load one KITTI sample from an HDF5 file.

    Args:
        path (str): path to a .h5 file with 'rgb' and 'depth' datasets.
            The 'rgb' dataset is stored channel-first and is transposed
            to H x W x C here.

    Returns:
        tuple: (rgb, depth) numpy arrays.
    """
    # Context manager guarantees the file handle is closed even if reading
    # raises; the original left the handle open (resource leak under heavy
    # DataLoader use).
    with h5py.File(path, "r") as h5f:
        rgb = np.array(h5f['rgb'])
        rgb = np.transpose(rgb, (1, 2, 0))
        depth = np.array(h5f['depth'])

    return rgb, depth
54+
55+
56+
oheight, owidth = 228, 912  # image size after pre-processing
# Shared color-jitter augmentation (brightness, contrast, saturation each 0.4,
# per the project `transforms` module); applied to RGB inputs in train_transform.
color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
58+
59+
60+
def train_transform(rgb, depth):
    """Training-time augmentation of a (rgb, depth) pair.

    Applies a random scale/rotation/flip geometric pipeline (identically to
    both images so they stay pixel-aligned), color jitter to RGB only, and
    normalizes RGB to [0, 1].

    Args:
        rgb: H x W x C image array.
        depth: H x W depth map.

    Returns:
        tuple: (rgb_np float64 in [0, 1], depth_np float32), both of size
        (oheight, owidth).
    """
    s = np.random.uniform(1.0, 1.5)  # random scaling
    # Depth is divided by the scale factor to stay consistent with the
    # spatial Resize(s) below.
    depth_np = depth / s
    angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    # Geometric transform shared by RGB and depth (1st part of augmentation).
    transform = transforms.Compose([
        transforms.Crop(130, 10, 240, 1200),
        transforms.Rotate(angle),
        transforms.Resize(s),
        transforms.CenterCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])
    rgb_np = transform(rgb)

    # random color jittering (photometric; RGB only)
    rgb_np = color_jitter(rgb_np)

    # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit float
    # dtype is the exact equivalent.
    rgb_np = np.asarray(rgb_np, dtype='float') / 255
    # Scipy affine_transform produced RuntimeError when the depth map was
    # given as a 'numpy.ndarray'
    depth_np = np.asarray(depth_np, dtype='float32')
    depth_np = transform(depth_np)

    return rgb_np, depth_np
87+
88+
89+
def val_transform(rgb, depth):
    """Validation-time preprocessing of a (rgb, depth) pair.

    Deterministic counterpart of train_transform: fixed crop plus center
    crop, no random augmentation; RGB is normalized to [0, 1].

    Args:
        rgb: H x W x C image array.
        depth: H x W depth map.

    Returns:
        tuple: (rgb_np float64 in [0, 1], depth_np float32), both of size
        (oheight, owidth).
    """
    depth_np = depth

    # Deterministic geometric transform shared by RGB and depth.
    transform = transforms.Compose([
        transforms.Crop(130, 10, 240, 1200),
        transforms.CenterCrop((oheight, owidth)),
    ])
    rgb_np = transform(rgb)
    # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit float
    # dtype is the exact equivalent.
    rgb_np = np.asarray(rgb_np, dtype='float') / 255
    depth_np = np.asarray(depth_np, dtype='float32')
    depth_np = transform(depth_np)

    return rgb_np, depth_np
103+
104+
105+
def rgb2grayscale(rgb):
    """Convert an (H, W, 3) RGB array to an (H, W) luminance map.

    Uses the ITU-R BT.601 luma weights (0.2989, 0.587, 0.114).
    """
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]
    return red * 0.2989 + green * 0.587 + blue * 0.114
107+
108+
109+
# Shared numpy-array -> torch-tensor converter from the project `transforms`
# module; used by KITTIDataset.__get_all_item__.
to_tensor = transforms.ToTensor()
110+
111+
112+
class KITTIDataset(data.Dataset):
    """KITTI dataset of (rgb, depth) pairs stored as .h5 files.

    Expects the layout ``root/<class>/**/*.h5``. Each sample is loaded,
    transformed (random augmentation for 'train', deterministic crops for
    'val'), and returned as an (input_tensor, depth_tensor) pair. The input
    is assembled from the requested modality: 'rgb', 'rgbd' (rgb plus a
    sparse-depth channel) or 'd' (sparse depth only).
    """

    modality_names = ['rgb', 'rgbd', 'd']  # , 'g', 'gd'

    def __init__(self, root, type, sparsifier=None, modality='rgb', loader=None):
        """
        Args:
            root (str): dataset root with one subdirectory per class.
            type (str): 'train' or 'val'; selects the transform pipeline.
            sparsifier: optional object providing dense_to_sparse(rgb, depth);
                None keeps the dense depth map.
            modality (str): one of modality_names.
            loader: callable(path) -> (rgb, depth); defaults to h5_loader.
        """
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        if type == 'train':
            self.transform = train_transform
        elif type == 'val':
            self.transform = val_transform
        else:
            raise (RuntimeError("Invalid dataset type: " + type + "\n"
                                "Supported dataset types are: train, val"))
        # Resolve the default loader at call time rather than binding
        # h5_loader into the signature at class-definition time; callers that
        # pass nothing still get h5_loader (backward compatible).
        self.loader = loader if loader is not None else h5_loader
        self.sparsifier = sparsifier

        if modality in self.modality_names:
            self.modality = modality
        else:
            # ', '.join (the original used ''.join, which printed the options
            # as the unreadable run "rgbrgbdd") and correct wording: these are
            # modalities, not dataset types.
            raise (RuntimeError("Invalid modality type: " + modality + "\n"
                                "Supported modalities are: " + ', '.join(self.modality_names)))

    def create_sparse_depth(self, rgb, depth):
        """Return depth subsampled through the sparsifier mask, or the dense
        depth unchanged when no sparsifier is configured."""
        if self.sparsifier is None:
            return depth
        else:
            mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
            sparse_depth = np.zeros(depth.shape)
            sparse_depth[mask_keep] = depth[mask_keep]
            return sparse_depth

    def create_rgbd(self, rgb, depth):
        """Stack rgb (H, W, 3) with a sparse-depth channel into (H, W, 4)."""
        sparse_depth = self.create_sparse_depth(rgb, depth)
        rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
        return rgbd

    def __getraw__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (rgb, depth) the raw data.
        """
        path, target = self.imgs[index]
        rgb, depth = self.loader(path)
        return rgb, depth

    def __get_all_item__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (input_tensor, depth_tensor, input_np, depth_np)
        """
        rgb, depth = self.__getraw__(index)
        if self.transform is not None:
            rgb_np, depth_np = self.transform(rgb, depth)
        else:
            raise(RuntimeError("transform not defined"))

        # Assemble the network input for the requested modality
        # (validated in __init__, so one of these branches always fires).
        if self.modality == 'rgb':
            input_np = rgb_np
        elif self.modality == 'rgbd':
            input_np = self.create_rgbd(rgb_np, depth_np)
        elif self.modality == 'd':
            input_np = self.create_sparse_depth(rgb_np, depth_np)

        input_tensor = to_tensor(input_np)
        # The 'd' modality yields a 2-D map; add channel axes until every
        # modality comes out as C x H x W.
        while input_tensor.dim() < 3:
            input_tensor = input_tensor.unsqueeze(0)
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)

        return input_tensor, depth_tensor, input_np, depth_np

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (input_tensor, depth_tensor)
        """
        input_tensor, depth_tensor, input_np, depth_np = self.__get_all_item__(
            index)
        return input_tensor, depth_tensor

    def __len__(self):
        """Number of samples discovered under root."""
        return len(self.imgs)

0 commit comments

Comments
 (0)