Skip to content

Commit 6d91c71

Browse files
author
Fangchang Ma
committed
Merge branch 'timethy-feature/sparse-to-dense'
2 parents cbf1f46 + 12055e0 commit 6d91c71

File tree

4 files changed

+185
-43
lines changed

4 files changed

+185
-43
lines changed

dense_to_sparse.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
import cv2
3+
4+
5+
def rgb2grayscale(rgb):
6+
return rgb[:, :, 0] * 0.2989 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
7+
8+
9+
class DenseToSparse:
10+
def __init__(self):
11+
pass
12+
13+
def dense_to_sparse(self, rgb, depth):
14+
pass
15+
16+
def __repr__(self):
17+
pass
18+
19+
20+
class UniformSampling(DenseToSparse):
21+
name = "uar"
22+
23+
def __init__(self, num_samples, max_depth=np.inf):
24+
DenseToSparse.__init__(self)
25+
self.num_samples = num_samples
26+
self.max_depth = max_depth
27+
28+
def __repr__(self):
29+
return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth)
30+
31+
def dense_to_sparse(self, rgb, depth):
32+
"""
33+
Samples pixels with `num_samples`/#pixels probability in `depth`.
34+
Only pixels with a maximum depth of `max_depth` are considered.
35+
If no `max_depth` is given, samples in all pixels
36+
"""
37+
if self.max_depth is not np.inf:
38+
mask_keep = depth <= self.max_depth
39+
n_keep = np.count_nonzero(mask_keep)
40+
if n_keep == 0:
41+
return mask_keep
42+
else:
43+
prob = float(self.num_samples) / n_keep
44+
return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob)
45+
else:
46+
prob = float(self.num_samples) / depth.size
47+
return np.random.uniform(0, 1, depth.shape) < prob
48+
49+
50+
class SimulatedStereo(DenseToSparse):
51+
name = "sim_stereo"
52+
53+
def __init__(self, num_samples, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1):
54+
DenseToSparse.__init__(self)
55+
self.num_samples = num_samples
56+
self.max_depth = max_depth
57+
self.dilate_kernel = dilate_kernel
58+
self.dilate_iterations = dilate_iterations
59+
60+
def __repr__(self):
61+
return "%s{ns=%d,md=%f,dil=%d.%d}" % \
62+
(self.name, self.num_samples, self.max_depth, self.dilate_kernel, self.dilate_iterations)
63+
64+
# We do not use cv2.Canny, since that applies non max suppression
65+
# So we simply do
66+
# RGB to intensitities
67+
# Smooth with gaussian
68+
# Take simple sobel gradients
69+
# Threshold the edge gradient
70+
# Dilatate
71+
def dense_to_sparse(self, rgb, depth):
72+
gray = rgb2grayscale(rgb)
73+
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
74+
gx = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=5)
75+
gy = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=5)
76+
77+
depth_mask = np.bitwise_and(depth != 0.0, depth <= self.max_depth)
78+
79+
edge_fraction = float(self.num_samples) / np.size(depth)
80+
81+
mag = cv2.magnitude(gx, gy)
82+
min_mag = np.percentile(mag[depth_mask], 100 * (1.0 - edge_fraction))
83+
mag_mask = mag >= min_mag
84+
85+
if self.dilate_iterations >= 0:
86+
kernel = np.ones((self.dilate_kernel, self.dilate_kernel), dtype=np.uint8)
87+
cv2.dilate(mag_mask.astype(np.uint8), kernel, iterations=self.dilate_iterations)
88+
89+
mask = np.bitwise_and(mag_mask, depth_mask)
90+
return mask

main.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
import sys
66
import csv
7+
import numpy as np
78

89
import torch
910
import torch.nn as nn
@@ -15,12 +16,14 @@
1516
from nyu_dataloader import NYUDataset
1617
from models import Decoder, ResNet
1718
from metrics import AverageMeter, Result
19+
from dense_to_sparse import UniformSampling, SimulatedStereo
1820
import criteria
1921
import utils
2022

2123
model_names = ['resnet18', 'resnet50']
2224
loss_names = ['l1', 'l2']
23-
data_names = ['NYUDataset']
25+
data_names = ['nyudepthv2']
26+
sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
2427
decoder_names = Decoder.names
2528
modality_names = NYUDataset.modality_names
2629

@@ -46,6 +49,13 @@
4649
' (default: rgb)')
4750
parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N',
4851
help='number of sparse depth samples (default: 0)')
52+
parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D',
53+
help='cut-off depth of sparsifier, negative values means infinity (default: inf [m])')
54+
parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=UniformSampling.name,
55+
choices=sparsifier_names,
56+
help='sparsifier: ' +
57+
' | '.join(sparsifier_names) +
58+
' (default: ' + UniformSampling.name + ')')
4959
parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2',
5060
choices=decoder_names,
5161
help='decoder: ' +
@@ -88,15 +98,24 @@
8898
def main():
8999
global args, best_result, output_directory, train_csv, test_csv
90100
args = parser.parse_args()
91-
args.data = os.path.join('data', args.data)
92101
if args.modality == 'rgb' and args.num_samples != 0:
93102
print("number of samples is forced to be 0 when input modality is rgb")
94103
args.num_samples = 0
95-
104+
if args.modality == 'rgb' and args.max_depth != 0.0:
105+
print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
106+
args.max_depth = 0.0
107+
108+
sparsifier = None
109+
max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
110+
if args.sparsifier == UniformSampling.name:
111+
sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
112+
elif args.sparsifier == SimulatedStereo.name:
113+
sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
114+
96115
# create results folder, if not already exists
97116
output_directory = os.path.join('results',
98-
'NYUDataset.modality={}.nsample={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
99-
format(args.modality, args.num_samples, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
117+
'{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
118+
format(args.data, sparsifier, args.modality, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
100119
if not os.path.exists(output_directory):
101120
os.makedirs(output_directory)
102121
train_csv = os.path.join(output_directory, 'train.csv')
@@ -112,19 +131,19 @@ def main():
112131

113132
# Data loading code
114133
print("=> creating data loaders ...")
115-
traindir = os.path.join(args.data, 'train')
116-
valdir = os.path.join(args.data, 'val')
134+
traindir = os.path.join('data', args.data, 'train')
135+
valdir = os.path.join('data', args.data, 'val')
117136

118-
train_dataset = NYUDataset(traindir, type='train',
119-
modality=args.modality, num_samples=args.num_samples)
137+
train_dataset = NYUDataset(traindir, type='train',
138+
modality=args.modality, sparsifier=sparsifier)
120139
train_loader = torch.utils.data.DataLoader(
121140
train_dataset, batch_size=args.batch_size, shuffle=True,
122141
num_workers=args.workers, pin_memory=True, sampler=None)
123142

124143
# set batch size to be 1 for validation
125-
val_dataset = NYUDataset(valdir, type='val',
126-
modality=args.modality, num_samples=args.num_samples)
127-
val_loader = torch.utils.data.DataLoader(val_dataset,
144+
val_dataset = NYUDataset(valdir, type='val',
145+
modality=args.modality, sparsifier=sparsifier)
146+
val_loader = torch.utils.data.DataLoader(val_dataset,
128147
batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
129148

130149
print("=> data loaders created.")
@@ -192,7 +211,7 @@ def main():
192211
adjust_learning_rate(optimizer, epoch)
193212

194213
# train for one epoch
195-
train(train_loader, model, criterion, optimizer, epoch)
214+
# train(train_loader, model, criterion, optimizer, epoch)
196215

197216
# evaluate on validation set
198217
result, img_merge = validate(val_loader, model, epoch)
@@ -306,11 +325,18 @@ def validate(val_loader, model, epoch, write_to_file=True):
306325
rgb = input
307326
elif args.modality == 'rgbd':
308327
rgb = input[:,:3,:,:]
328+
depth = input[:,3:,:,:]
309329

310330
if i == 0:
311-
img_merge = utils.merge_into_row(rgb, target, depth_pred)
331+
if args.modality == 'rgbd':
332+
img_merge = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
333+
else:
334+
img_merge = utils.merge_into_row(rgb, target, depth_pred)
312335
elif (i < 8*skip) and (i % skip == 0):
313-
row = utils.merge_into_row(rgb, target, depth_pred)
336+
if args.modality == 'rgbd':
337+
row = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
338+
else:
339+
row = utils.merge_into_row(rgb, target, depth_pred)
314340
img_merge = utils.add_row(img_merge, row)
315341
elif i == 8*skip:
316342
filename = output_directory + '/comparison_' + str(epoch) + '.png'

nyu_dataloader.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def rgb2grayscale(rgb):
9696
class NYUDataset(data.Dataset):
9797
modality_names = ['rgb', 'rgbd', 'd'] # , 'g', 'gd'
9898

99-
def __init__(self, root, type, modality='rgb', num_samples=0, loader=h5_loader):
99+
def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
100100
classes, class_to_idx = find_classes(root)
101101
imgs = make_dataset(root, class_to_idx)
102102
if len(imgs) == 0:
@@ -115,28 +115,25 @@ def __init__(self, root, type, modality='rgb', num_samples=0, loader=h5_loader):
115115
raise (RuntimeError("Invalid dataset type: " + type + "\n"
116116
"Supported dataset types are: train, val"))
117117
self.loader = loader
118+
self.sparsifier = sparsifier
118119

119120
if modality in self.modality_names:
120121
self.modality = modality
121-
if modality in ['rgbd', 'd', 'gd']:
122-
if num_samples <= 0:
123-
raise (RuntimeError("Invalid number of samples: {}\n".format(num_samples)))
124-
self.num_samples = num_samples
125-
else:
126-
self.num_samples = 0
127122
else:
128123
raise (RuntimeError("Invalid modality type: " + modality + "\n"
129124
"Supported dataset types are: " + ''.join(self.modality_names)))
130125

131-
def create_sparse_depth(self, depth, num_samples):
132-
prob = float(num_samples) / depth.size
133-
mask_keep = np.random.uniform(0, 1, depth.shape) < prob
134-
sparse_depth = np.zeros(depth.shape)
135-
sparse_depth[mask_keep] = depth[mask_keep]
136-
return sparse_depth
126+
def create_sparse_depth(self, rgb, depth):
127+
if self.sparsifier is None:
128+
return depth
129+
else:
130+
mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
131+
sparse_depth = np.zeros(depth.shape)
132+
sparse_depth[mask_keep] = depth[mask_keep]
133+
return sparse_depth
137134

138-
def create_rgbd(self, rgb, depth, num_samples):
139-
sparse_depth = self.create_sparse_depth(depth, num_samples)
135+
def create_rgbd(self, rgb, depth):
136+
sparse_depth = self.create_sparse_depth(rgb, depth)
140137
# rgbd = np.dstack((rgb[:,:,0], rgb[:,:,1], rgb[:,:,2], sparse_depth))
141138
rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
142139
return rgbd
@@ -170,13 +167,13 @@ def __get_all_item__(self, index):
170167
# color normalization
171168
# rgb_tensor = normalize_rgb(rgb_tensor)
172169
# rgb_np = normalize_np(rgb_np)
173-
170+
174171
if self.modality == 'rgb':
175172
input_np = rgb_np
176173
elif self.modality == 'rgbd':
177-
input_np = self.create_rgbd(rgb_np, depth_np, self.num_samples)
174+
input_np = self.create_rgbd(rgb_np, depth_np)
178175
elif self.modality == 'd':
179-
input_np = self.create_sparse_depth(depth_np, self.num_samples)
176+
input_np = self.create_sparse_depth(rgb_np, depth_np)
180177

181178
input_tensor = to_tensor(input_np)
182179
while input_tensor.dim() < 3:

utils.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,53 @@
22
import matplotlib.pyplot as plt
33
from PIL import Image
44

5-
cmap = plt.cm.jet
5+
cmap = plt.cm.viridis
66

7-
def merge_into_row(input, target, depth_pred):
7+
8+
def colored_depthmap(depth, d_min=None, d_max=None):
9+
if d_min is None:
10+
d_min = np.min(depth)
11+
if d_max is None:
12+
d_max = np.max(depth)
13+
depth_relative = (depth - d_min) / (d_max - d_min)
14+
return 255 * cmap(depth_relative)[:,:,:3] # H, W, C
15+
16+
17+
def merge_into_row(input, depth_target, depth_pred):
818
rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C
9-
depth = np.squeeze(target.cpu().numpy())
10-
depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
11-
depth = 255 * cmap(depth)[:,:,:3] # H, W, C
12-
pred = np.squeeze(depth_pred.data.cpu().numpy())
13-
pred = (pred - np.min(pred)) / (np.max(pred) - np.min(pred))
14-
pred = 255 * cmap(pred)[:,:,:3] # H, W, C
15-
img_merge = np.hstack([rgb, depth, pred])
19+
depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
20+
depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
21+
22+
d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu))
23+
d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu))
24+
depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
25+
depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
26+
img_merge = np.hstack([rgb, depth_target_col, depth_pred_col])
1627

17-
# img_merge.save(output_directory + '/comparison_' + str(epoch) + '.png')
1828
return img_merge
1929

30+
31+
def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred):
32+
rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C
33+
depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
34+
depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
35+
depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
36+
37+
d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
38+
d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
39+
depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
40+
depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
41+
depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
42+
43+
img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col])
44+
45+
return img_merge
46+
47+
2048
def add_row(img_merge, row):
2149
return np.vstack([img_merge, row])
2250

51+
2352
def save_image(img_merge, filename):
2453
img_merge = Image.fromarray(img_merge.astype('uint8'))
2554
img_merge.save(filename)

0 commit comments

Comments
 (0)