Skip to content

Commit 633657d

Browse files
author
Fangchang Ma
committed
Merge branch 'AkariAsai-master'
2 parents fff3469 + 68460f3 commit 633657d

File tree

8 files changed

+317
-231
lines changed

8 files changed

+317
-231
lines changed

nyu_dataloader.py renamed to dataloaders/dataloader.py

Lines changed: 52 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
import numpy as np
44
import torch.utils.data as data
55
import h5py
6-
import transforms
6+
import dataloaders.transforms as transforms
77

8-
IMG_EXTENSIONS = [
9-
'.h5',
10-
]
8+
IMG_EXTENSIONS = ['.h5',]
119

1210
def is_image_file(filename):
1311
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
@@ -22,106 +20,60 @@ def make_dataset(dir, class_to_idx):
2220
images = []
2321
dir = os.path.expanduser(dir)
2422
for target in sorted(os.listdir(dir)):
25-
# print(target)
2623
d = os.path.join(dir, target)
2724
if not os.path.isdir(d):
2825
continue
29-
3026
for root, _, fnames in sorted(os.walk(d)):
3127
for fname in sorted(fnames):
3228
if is_image_file(fname):
3329
path = os.path.join(root, fname)
3430
item = (path, class_to_idx[target])
3531
images.append(item)
36-
3732
return images
3833

3934
def h5_loader(path):
4035
h5f = h5py.File(path, "r")
4136
rgb = np.array(h5f['rgb'])
4237
rgb = np.transpose(rgb, (1, 2, 0))
4338
depth = np.array(h5f['depth'])
44-
4539
return rgb, depth
4640

47-
iheight, iwidth = 480, 640 # raw image size
48-
oheight, owidth = 228, 304 # image size after pre-processing
49-
color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
50-
51-
def train_transform(rgb, depth):
52-
s = np.random.uniform(1.0, 1.5) # random scaling
53-
# print("scale factor s={}".format(s))
54-
depth_np = depth / s
55-
angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
56-
do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
57-
58-
# perform 1st part of data augmentation
59-
transform = transforms.Compose([
60-
transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation is very slow
61-
transforms.Rotate(angle),
62-
transforms.Resize(s),
63-
transforms.CenterCrop((oheight, owidth)),
64-
transforms.HorizontalFlip(do_flip)
65-
])
66-
rgb_np = transform(rgb)
67-
68-
# random color jittering
69-
rgb_np = color_jitter(rgb_np)
70-
71-
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
72-
depth_np = transform(depth_np)
73-
74-
return rgb_np, depth_np
75-
76-
def val_transform(rgb, depth):
77-
depth_np = depth
78-
79-
# perform 1st part of data augmentation
80-
transform = transforms.Compose([
81-
transforms.Resize(240.0 / iheight),
82-
transforms.CenterCrop((oheight, owidth)),
83-
])
84-
rgb_np = transform(rgb)
85-
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
86-
depth_np = transform(depth_np)
87-
88-
return rgb_np, depth_np
89-
90-
def rgb2grayscale(rgb):
91-
return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
92-
41+
# def rgb2grayscale(rgb):
42+
# return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
9343

9444
to_tensor = transforms.ToTensor()
9545

96-
class NYUDataset(data.Dataset):
46+
class MyDataloader(data.Dataset):
9747
modality_names = ['rgb', 'rgbd', 'd'] # , 'g', 'gd'
9848

9949
def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
10050
classes, class_to_idx = find_classes(root)
10151
imgs = make_dataset(root, class_to_idx)
102-
if len(imgs) == 0:
103-
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
104-
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
105-
52+
assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n"
53+
print("Found {} images in {} folder.".format(len(imgs), type))
10654
self.root = root
10755
self.imgs = imgs
10856
self.classes = classes
10957
self.class_to_idx = class_to_idx
11058
if type == 'train':
111-
self.transform = train_transform
59+
self.transform = self.train_transform
11260
elif type == 'val':
113-
self.transform = val_transform
61+
self.transform = self.val_transform
11462
else:
11563
raise (RuntimeError("Invalid dataset type: " + type + "\n"
11664
"Supported dataset types are: train, val"))
11765
self.loader = loader
11866
self.sparsifier = sparsifier
11967

120-
if modality in self.modality_names:
121-
self.modality = modality
122-
else:
123-
raise (RuntimeError("Invalid modality type: " + modality + "\n"
124-
"Supported dataset types are: " + ''.join(self.modality_names)))
68+
assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
69+
"Supported dataset types are: " + ''.join(self.modality_names)
70+
self.modality = modality
71+
72+
def train_transform(self, rgb, depth):
73+
raise (RuntimeError("train_transform() is not implemented. "))
74+
75+
def val_transform(rgb, depth):
76+
raise (RuntimeError("val_transform() is not implemented."))
12577

12678
def create_sparse_depth(self, rgb, depth):
12779
if self.sparsifier is None:
@@ -134,7 +86,6 @@ def create_sparse_depth(self, rgb, depth):
13486

13587
def create_rgbd(self, rgb, depth):
13688
sparse_depth = self.create_sparse_depth(rgb, depth)
137-
# rgbd = np.dstack((rgb[:,:,0], rgb[:,:,1], rgb[:,:,2], sparse_depth))
13889
rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
13990
return rgbd
14091

@@ -150,14 +101,7 @@ def __getraw__(self, index):
150101
rgb, depth = self.loader(path)
151102
return rgb, depth
152103

153-
def __get_all_item__(self, index):
154-
"""
155-
Args:
156-
index (int): Index
157-
158-
Returns:
159-
tuple: (input_tensor, depth_tensor, input_np, depth_np)
160-
"""
104+
def __getitem__(self, index):
161105
rgb, depth = self.__getraw__(index)
162106
if self.transform is not None:
163107
rgb_np, depth_np = self.transform(rgb, depth)
@@ -181,19 +125,40 @@ def __get_all_item__(self, index):
181125
depth_tensor = to_tensor(depth_np)
182126
depth_tensor = depth_tensor.unsqueeze(0)
183127

184-
return input_tensor, depth_tensor, input_np, depth_np
185-
186-
def __getitem__(self, index):
187-
"""
188-
Args:
189-
index (int): Index
190-
191-
Returns:
192-
tuple: (input_tensor, depth_tensor)
193-
"""
194-
input_tensor, depth_tensor, input_np, depth_np = self.__get_all_item__(index)
195-
196128
return input_tensor, depth_tensor
197129

198130
def __len__(self):
199131
return len(self.imgs)
132+
133+
# def __get_all_item__(self, index):
134+
# """
135+
# Args:
136+
# index (int): Index
137+
138+
# Returns:
139+
# tuple: (input_tensor, depth_tensor, input_np, depth_np)
140+
# """
141+
# rgb, depth = self.__getraw__(index)
142+
# if self.transform is not None:
143+
# rgb_np, depth_np = self.transform(rgb, depth)
144+
# else:
145+
# raise(RuntimeError("transform not defined"))
146+
147+
# # color normalization
148+
# # rgb_tensor = normalize_rgb(rgb_tensor)
149+
# # rgb_np = normalize_np(rgb_np)
150+
151+
# if self.modality == 'rgb':
152+
# input_np = rgb_np
153+
# elif self.modality == 'rgbd':
154+
# input_np = self.create_rgbd(rgb_np, depth_np)
155+
# elif self.modality == 'd':
156+
# input_np = self.create_sparse_depth(rgb_np, depth_np)
157+
158+
# input_tensor = to_tensor(input_np)
159+
# while input_tensor.dim() < 3:
160+
# input_tensor = input_tensor.unsqueeze(0)
161+
# depth_tensor = to_tensor(depth_np)
162+
# depth_tensor = depth_tensor.unsqueeze(0)
163+
164+
# return input_tensor, depth_tensor, input_np, depth_np
File renamed without changes.

dataloaders/kitti_dataloader.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import dataloaders.transforms as transforms
3+
from dataloaders.dataloader import MyDataloader
4+
5+
color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
6+
7+
class KITTIDataset(MyDataloader):
8+
def __init__(self, root, type, sparsifier=None, modality='rgb'):
9+
super(KITTIDataset, self).__init__(root, type, sparsifier=None, modality='rgb')
10+
self.output_size = (228, 912)
11+
12+
def train_transform(self, rgb, depth):
13+
s = np.random.uniform(1.0, 1.5) # random scaling
14+
depth_np = depth / s
15+
angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
16+
do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
17+
18+
# perform 1st step of data augmentation
19+
transform = transforms.Compose([
20+
transforms.Crop(130, 10, 240, 1200),
21+
transforms.Rotate(angle),
22+
transforms.Resize(s),
23+
transforms.CenterCrop(self.output_size),
24+
transforms.HorizontalFlip(do_flip)
25+
])
26+
rgb_np = transform(rgb)
27+
rgb_np = color_jitter(rgb_np) # random color jittering
28+
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
29+
# Scipy affine_transform produced RuntimeError when the depth map was
30+
# given as a 'numpy.ndarray'
31+
depth_np = np.asfarray(depth_np, dtype='float32')
32+
depth_np = transform(depth_np)
33+
34+
return rgb_np, depth_np
35+
36+
def val_transform(self, rgb, depth):
37+
depth_np = depth
38+
transform = transforms.Compose([
39+
transforms.Crop(130, 10, 240, 1200),
40+
transforms.CenterCrop(self.output_size),
41+
])
42+
rgb_np = transform(rgb)
43+
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
44+
depth_np = np.asfarray(depth_np, dtype='float32')
45+
depth_np = transform(depth_np)
46+
47+
return rgb_np, depth_np
48+

dataloaders/nyu_dataloader.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
import dataloaders.transforms as transforms
3+
from dataloaders.dataloader import MyDataloader
4+
5+
iheight, iwidth = 480, 640 # raw image size
6+
color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
7+
8+
class NYUDataset(MyDataloader):
9+
def __init__(self, root, type, sparsifier=None, modality='rgb'):
10+
super(NYUDataset, self).__init__(root, type, sparsifier=None, modality='rgb')
11+
self.output_size = (228, 304)
12+
13+
def train_transform(self, rgb, depth):
14+
s = np.random.uniform(1.0, 1.5) # random scaling
15+
depth_np = depth / s
16+
angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
17+
do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
18+
19+
# perform 1st step of data augmentation
20+
transform = transforms.Compose([
21+
transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation can be slow
22+
transforms.Rotate(angle),
23+
transforms.Resize(s),
24+
transforms.CenterCrop(self.output_size),
25+
transforms.HorizontalFlip(do_flip)
26+
])
27+
rgb_np = transform(rgb)
28+
rgb_np = color_jitter(rgb_np) # random color jittering
29+
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
30+
depth_np = transform(depth_np)
31+
32+
return rgb_np, depth_np
33+
34+
def val_transform(self, rgb, depth):
35+
depth_np = depth
36+
transform = transforms.Compose([
37+
transforms.Resize(240.0 / iheight),
38+
transforms.CenterCrop(self.output_size),
39+
])
40+
rgb_np = transform(rgb)
41+
rgb_np = np.asfarray(rgb_np, dtype='float') / 255
42+
depth_np = transform(depth_np)
43+
44+
return rgb_np, depth_np

transforms.py renamed to dataloaders/transforms.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def get_params(img, output_size):
376376
# # randomized cropping
377377
# i = np.random.randint(i-3, i+4)
378378
# j = np.random.randint(j-3, j+4)
379-
379+
380380
return i, j, th, tw
381381

382382
def __call__(self, img):
@@ -514,3 +514,47 @@ def __call__(self, img):
514514
transform = self.get_params(self.brightness, self.contrast,
515515
self.saturation, self.hue)
516516
return np.array(transform(pil))
517+
518+
class Crop(object):
519+
"""Crops the given PIL Image to a rectangular region based on a given
520+
4-tuple defining the upper and left pixel coordinates, height, and width of the crop.
521+
522+
Args:
523+
a tuple: (upper pixel coordinate, left pixel coordinate, height, width)-tuple
524+
"""
525+
526+
def __init__(self, i, j, h, w):
527+
"""
528+
i: Upper pixel coordinate.
529+
j: Left pixel coordinate.
530+
h: Height of the cropped image.
531+
w: Width of the cropped image.
532+
"""
533+
self.i = i
534+
self.j = j
535+
self.h = h
536+
self.w = w
537+
538+
def __call__(self, img):
539+
"""
540+
Args:
541+
img (numpy.ndarray (C x H x W)): Image to be cropped.
542+
Returns:
543+
img (numpy.ndarray (C x H x W)): Cropped image.
544+
"""
545+
546+
i, j, h, w = self.i, self.j, self.h, self.w
547+
548+
if not(_is_numpy_image(img)):
549+
raise TypeError('img should be ndarray. Got {}'.format(type(img)))
550+
if img.ndim == 3:
551+
return img[i:i + h, j:j + w, :]
552+
elif img.ndim == 2:
553+
return img[i:i + h, j:j + w]
554+
else:
555+
raise RuntimeError(
556+
'img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
557+
558+
def __repr__(self):
559+
return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format(
560+
self.i, self.j, self.h, self.w)

0 commit comments

Comments
 (0)