import numpy as np
import torch.utils.data as data
import h5py
-import transforms
+import dataloaders.transforms as transforms

-IMG_EXTENSIONS = [
-    '.h5',
-]
+IMG_EXTENSIONS = ['.h5',]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
@@ -22,106 +20,60 @@ def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
-        # print(target)
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
-
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)
-
    return images

def h5_loader(path):
    h5f = h5py.File(path, "r")
    rgb = np.array(h5f['rgb'])
    rgb = np.transpose(rgb, (1, 2, 0))
    depth = np.array(h5f['depth'])
-
    return rgb, depth

-iheight, iwidth = 480, 640  # raw image size
-oheight, owidth = 228, 304  # image size after pre-processing
-color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
-
-def train_transform(rgb, depth):
-    s = np.random.uniform(1.0, 1.5)  # random scaling
-    # print("scale factor s={}".format(s))
-    depth_np = depth / s
-    angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
-    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip
-
-    # perform 1st part of data augmentation
-    transform = transforms.Compose([
-        transforms.Resize(250.0 / iheight),  # this is for computational efficiency, since rotation is very slow
-        transforms.Rotate(angle),
-        transforms.Resize(s),
-        transforms.CenterCrop((oheight, owidth)),
-        transforms.HorizontalFlip(do_flip)
-    ])
-    rgb_np = transform(rgb)
-
-    # random color jittering
-    rgb_np = color_jitter(rgb_np)
-
-    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
-    depth_np = transform(depth_np)
-
-    return rgb_np, depth_np
-
-def val_transform(rgb, depth):
-    depth_np = depth
-
-    # perform 1st part of data augmentation
-    transform = transforms.Compose([
-        transforms.Resize(240.0 / iheight),
-        transforms.CenterCrop((oheight, owidth)),
-    ])
-    rgb_np = transform(rgb)
-    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
-    depth_np = transform(depth_np)
-
-    return rgb_np, depth_np
-
-def rgb2grayscale(rgb):
-    return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
-
+# def rgb2grayscale(rgb):
+#     return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114

to_tensor = transforms.ToTensor()

-class NYUDataset(data.Dataset):
+class MyDataloader(data.Dataset):
    modality_names = ['rgb', 'rgbd', 'd']  # , 'g', 'gd'

    def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
-        if len(imgs) == 0:
-            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
-                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
-
+        assert len(imgs) > 0, "Found 0 images in subfolders of: " + root + "\n"
+        print("Found {} images in {} folder.".format(len(imgs), type))
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        if type == 'train':
-            self.transform = train_transform
+            self.transform = self.train_transform
        elif type == 'val':
-            self.transform = val_transform
+            self.transform = self.val_transform
        else:
            raise(RuntimeError("Invalid dataset type: " + type + "\n"
                "Supported dataset types are: train, val"))
        self.loader = loader
        self.sparsifier = sparsifier

-        if modality in self.modality_names:
-            self.modality = modality
-        else:
-            raise(RuntimeError("Invalid modality type: " + modality + "\n"
-                "Supported dataset types are: " + ''.join(self.modality_names)))
+        assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
+            "Supported dataset types are: " + ''.join(self.modality_names)
+        self.modality = modality
+
+    def train_transform(self, rgb, depth):
+        raise(RuntimeError("train_transform() is not implemented."))
+
+    def val_transform(self, rgb, depth):
+        raise(RuntimeError("val_transform() is not implemented."))

    def create_sparse_depth(self, rgb, depth):
        if self.sparsifier is None:
@@ -134,7 +86,6 @@ def create_sparse_depth(self, rgb, depth):

    def create_rgbd(self, rgb, depth):
        sparse_depth = self.create_sparse_depth(rgb, depth)
-        # rgbd = np.dstack((rgb[:,:,0], rgb[:,:,1], rgb[:,:,2], sparse_depth))
        rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
        return rgbd

@@ -150,14 +101,7 @@ def __getraw__(self, index):
        rgb, depth = self.loader(path)
        return rgb, depth

-    def __get_all_item__(self, index):
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            tuple: (input_tensor, depth_tensor, input_np, depth_np)
-        """
+    def __getitem__(self, index):
        rgb, depth = self.__getraw__(index)
        if self.transform is not None:
            rgb_np, depth_np = self.transform(rgb, depth)
@@ -181,19 +125,40 @@ def __get_all_item__(self, index):
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)

-        return input_tensor, depth_tensor, input_np, depth_np
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            tuple: (input_tensor, depth_tensor)
-        """
-        input_tensor, depth_tensor, input_np, depth_np = self.__get_all_item__(index)
-
        return input_tensor, depth_tensor

    def __len__(self):
        return len(self.imgs)
+
+    # def __get_all_item__(self, index):
+    #     """
+    #     Args:
+    #         index (int): Index
+
+    #     Returns:
+    #         tuple: (input_tensor, depth_tensor, input_np, depth_np)
+    #     """
+    #     rgb, depth = self.__getraw__(index)
+    #     if self.transform is not None:
+    #         rgb_np, depth_np = self.transform(rgb, depth)
+    #     else:
+    #         raise(RuntimeError("transform not defined"))
+
+    #     # color normalization
+    #     # rgb_tensor = normalize_rgb(rgb_tensor)
+    #     # rgb_np = normalize_np(rgb_np)
+
+    #     if self.modality == 'rgb':
+    #         input_np = rgb_np
+    #     elif self.modality == 'rgbd':
+    #         input_np = self.create_rgbd(rgb_np, depth_np)
+    #     elif self.modality == 'd':
+    #         input_np = self.create_sparse_depth(rgb_np, depth_np)
+
+    #     input_tensor = to_tensor(input_np)
+    #     while input_tensor.dim() < 3:
+    #         input_tensor = input_tensor.unsqueeze(0)
+    #     depth_tensor = to_tensor(depth_np)
+    #     depth_tensor = depth_tensor.unsqueeze(0)
+
+    #     return input_tensor, depth_tensor, input_np, depth_np
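
The diff turns the old NYUDataset into a generic base class: MyDataloader keeps the file discovery, sparsification, and tensor-conversion logic, while train_transform() and val_transform() become hooks that a dataset-specific subclass overrides. A minimal sketch of such a subclass follows; it simply relocates the augmentation pipeline that this commit removes from the module level. The NYUDataset name, the dataloaders.dataloader module path, and the __init__ signature are assumptions for illustration, not part of this diff.

# Hypothetical subclass sketch (not part of this commit). The augmentation steps
# are taken verbatim from the module-level train_transform/val_transform removed above.
import numpy as np
import dataloaders.transforms as transforms
from dataloaders.dataloader import MyDataloader   # assumed module path

iheight, iwidth = 480, 640    # raw NYU image size
oheight, owidth = 228, 304    # image size after pre-processing

class NYUDataset(MyDataloader):
    def __init__(self, root, type, sparsifier=None, modality='rgb'):
        super(NYUDataset, self).__init__(root, type, sparsifier, modality)
        self.color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)

    def train_transform(self, rgb, depth):
        s = np.random.uniform(1.0, 1.5)               # random scaling
        depth_np = depth / s
        angle = np.random.uniform(-5.0, 5.0)          # random rotation in degrees
        do_flip = np.random.uniform(0.0, 1.0) < 0.5   # random horizontal flip

        transform = transforms.Compose([
            transforms.Resize(250.0 / iheight),       # shrink first; rotation is slow
            transforms.Rotate(angle),
            transforms.Resize(s),
            transforms.CenterCrop((oheight, owidth)),
            transforms.HorizontalFlip(do_flip),
        ])
        rgb_np = transform(rgb)
        rgb_np = self.color_jitter(rgb_np)            # random color jittering
        rgb_np = np.asfarray(rgb_np, dtype='float') / 255
        depth_np = transform(depth_np)
        return rgb_np, depth_np

    def val_transform(self, rgb, depth):
        depth_np = depth
        transform = transforms.Compose([
            transforms.Resize(240.0 / iheight),
            transforms.CenterCrop((oheight, owidth)),
        ])
        rgb_np = transform(rgb)
        rgb_np = np.asfarray(rgb_np, dtype='float') / 255
        depth_np = transform(depth_np)
        return rgb_np, depth_np

Because MyDataloader already implements __getitem__ and __len__, a subclass like this can be handed directly to torch.utils.data.DataLoader, and the constructor's self.train_transform / self.val_transform indirection resolves to the subclass methods at runtime.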