Skip to content

Commit 559ab9e

Browse files
author
Clément Pinard
committed
fix imports
won't error if cv2 is not avalaible when not necessary imageio is used instead of scipy
1 parent d95f630 commit 559ab9e

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

datasets/KITTI.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,31 @@
33
import glob
44
from .listdataset import ListDataset
55
from .util import split2list
6-
import cv2
76
import numpy as np
87
import flow_transforms
98

9+
try:
10+
import cv2
11+
except ImportError as e:
12+
import warnings
13+
with warnings.catch_warnings():
14+
warnings.filterwarnings("default", category=ImportWarning)
15+
warnings.warn("failed to load openCV, which is needed"
16+
"for KITTI which uses 16bit PNG images", ImportWarning)
17+
1018
'''
1119
Dataset routines for KITTI_flow, 2012 and 2015.
1220
http://www.cvlibs.net/datasets/kitti/eval_flow.php
1321
The dataset is not very big, you might want to only finetune on it for flownet
1422
EPE are not representative in this dataset because of the sparsity of the GT.
23+
OpenCV is needed to load 16bit png images
1524
'''
1625

1726

1827
def load_flow_from_png(png_path):
19-
flo_file = cv2.imread(png_path,cv2.IMREAD_UNCHANGED)
28+
# The -1 is here to specify not to change the image depth (16bit), and is compatible
29+
# with both OpenCV2 and OpenCV3
30+
flo_file = cv2.imread(png_path, -1)
2031
flo_img = flo_file[:,:,2:0:-1].astype(np.float32)
2132
invalid = (flo_file[:,:,0] == 0)
2233
flo_img = flo_img - 32768
@@ -27,24 +38,25 @@ def load_flow_from_png(png_path):
2738

2839

2940
def make_dataset(dir, split, occ=True):
30-
'''Will search in training folder for folders 'flow_noc' or 'flow_occ' and 'colored_0' (KITTI 2012) or 'image_2' (KITTI 2015) '''
41+
'''Will search in training folder for folders 'flow_noc' or 'flow_occ'
42+
and 'colored_0' (KITTI 2012) or 'image_2' (KITTI 2015) '''
3143
flow_dir = 'flow_occ' if occ else 'flow_noc'
32-
assert(os.path.isdir(os.path.join(dir,flow_dir)))
44+
assert(os.path.isdir(os.path.join(dir, flow_dir)))
3345
img_dir = 'colored_0'
34-
if not os.path.isdir(os.path.join(dir,img_dir)):
46+
if not os.path.isdir(os.path.join(dir, img_dir)):
3547
img_dir = 'image_2'
36-
assert(os.path.isdir(os.path.join(dir,img_dir)))
48+
assert(os.path.isdir(os.path.join(dir, img_dir)))
3749

3850
images = []
39-
for flow_map in glob.iglob(os.path.join(dir,flow_dir,'*.png')):
51+
for flow_map in glob.iglob(os.path.join(dir, flow_dir, '*.png')):
4052
flow_map = os.path.basename(flow_map)
4153
root_filename = flow_map[:-7]
42-
flow_map = os.path.join(flow_dir,flow_map)
43-
img1 = os.path.join(img_dir,root_filename+'_10.png')
44-
img2 = os.path.join(img_dir,root_filename+'_11.png')
45-
if not (os.path.isfile(os.path.join(dir,img1)) or os.path.isfile(os.path.join(dir,img2))):
54+
flow_map = os.path.join(flow_dir, flow_map)
55+
img1 = os.path.join(img_dir, root_filename+'_10.png')
56+
img2 = os.path.join(img_dir, root_filename+'_11.png')
57+
if not (os.path.isfile(os.path.join(dir, img1)) or os.path.isfile(os.path.join(dir, img2))):
4658
continue
47-
images.append([[img1,img2],flow_map])
59+
images.append([[img1, img2], flow_map])
4860

4961
return split2list(images, split, default_split=0.9)
5062

@@ -58,9 +70,13 @@ def KITTI_loader(root,path_imgs, path_flo):
5870
def KITTI_occ(root, transform=None, target_transform=None,
5971
co_transform=None, split=None):
6072
train_list, test_list = make_dataset(root, split, True)
61-
train_dataset = ListDataset(root, train_list, transform, target_transform, co_transform, loader=KITTI_loader)
73+
train_dataset = ListDataset(root, train_list, transform,
74+
target_transform, co_transform,
75+
loader=KITTI_loader)
6276
# All test sample are cropped to lowest possible size of KITTI images
63-
test_dataset = ListDataset(root, test_list, transform, target_transform, flow_transforms.CenterCrop((370,1224)), loader=KITTI_loader)
77+
test_dataset = ListDataset(root, test_list, transform,
78+
target_transform, flow_transforms.CenterCrop((370,1224)),
79+
loader=KITTI_loader)
6480

6581
return train_dataset, test_dataset
6682

datasets/listdataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch.utils.data as data
22
import os
33
import os.path
4-
from scipy.ndimage import imread
4+
from imageio import imread
55
import numpy as np
66

77

0 commit comments

Comments
 (0)