
Commit d95f630

Author: Clément Pinard (committed)

upgrade to pytorch 0.4.1 syntax
1 parent d21ba48 commit d95f630

File tree

3 files changed: +50 -42 lines changed


main.py

Lines changed: 33 additions & 33 deletions
@@ -4,6 +4,7 @@
 import time
 
 import torch
+import torch.nn.functional as F
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
 import torch.optim
@@ -19,7 +20,6 @@
 
 model_names = sorted(name for name in models.__dict__
                      if name.islower() and not name.startswith("__"))
-
 dataset_names = sorted(name for name in datasets.__all__)
 
 

@@ -82,6 +82,7 @@
 
 best_EPE = -1
 n_iter = 0
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def main():
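The hunk above introduces the PyTorch 0.4 device idiom. A minimal sketch of how such a module-level device handle is typically used (the tensor shape here is illustrative, not taken from the repo):

import torch

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# .to(device) moves tensors (and modules) in one call, replacing the
# pre-0.4 pattern of calling .cuda() on every tensor by hand.
x = torch.randn(2, 6, 64, 64).to(device)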
@@ -193,7 +194,8 @@ def main():
 
         # evaluate on validation set
 
-        EPE = validate(val_loader, model, epoch, output_writers)
+        with torch.no_grad():
+            EPE = validate(val_loader, model, epoch, output_writers)
         test_writer.add_scalar('mean EPE', EPE, epoch)
 
         if best_EPE < 0:
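For context, `volatile=True` was removed in 0.4; inference now runs under `torch.no_grad()` instead. A self-contained sketch with a stand-in model rather than the repo's FlowNet:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # stand-in for the real model
x = torch.randn(8, 4)

with torch.no_grad():     # no autograd graph is recorded, saving memory
    y = model(x)
assert not y.requires_grad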
@@ -227,25 +229,23 @@ def train(train_loader, model, optimizer, epoch, train_writer):
     for i, (input, target) in enumerate(train_loader):
         # measure data loading time
         data_time.update(time.time() - end)
-        target = target.cuda(async=True)
-        input = [j.cuda() for j in input]
-        input_var = torch.autograd.Variable(torch.cat(input,1))
-        target_var = torch.autograd.Variable(target)
+        target = target.to(device)
+        input = torch.cat(input,1).to(device)
 
         # compute output
-        output = model(input_var)
+        output = model(input)
         if args.sparse:
             # Since Target pooling is not very precise when sparse,
             # take the highest resolution prediction and upsample it instead of downsampling target
-            h, w = target_var.size()[-2:]
-            output = [torch.nn.functional.upsample(output[0], (h,w)), *output[1:]]
+            h, w = target.size()[-2:]
+            output = [F.interpolate(output[0], (h,w)), *output[1:]]
 
-        loss = multiscaleEPE(output, target_var, weights=args.multiscale_weights, sparse=args.sparse)
-        flow2_EPE = args.div_flow * realEPE(output[0], target_var, sparse=args.sparse)
+        loss = multiscaleEPE(output, target, weights=args.multiscale_weights, sparse=args.sparse)
+        flow2_EPE = args.div_flow * realEPE(output[0], target, sparse=args.sparse)
         # record loss and EPE
-        losses.update(loss.data[0], target.size(0))
-        train_writer.add_scalar('train_loss', loss.data[0], n_iter)
-        flow2_EPEs.update(flow2_EPE.data[0], target.size(0))
+        losses.update(loss.item(), target.size(0))
+        train_writer.add_scalar('train_loss', loss.item(), n_iter)
+        flow2_EPEs.update(flow2_EPE.item(), target.size(0))
 
         # compute gradient and do optimization step
         optimizer.zero_grad()
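The `.data[0]` to `.item()` change reflects that losses are 0-dimensional tensors in 0.4, where indexing with `[0]` raises an error. A tiny illustration (the value is made up):

import torch

loss = torch.tensor(0.25)   # a 0-dim tensor, as losses are in 0.4
value = loss.item()         # replaces the removed `loss.data[0]`
print(value)                # plain Python float, safe to log or average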
@@ -278,26 +278,26 @@ def validate(val_loader, model, epoch, output_writers):
 
     end = time.time()
     for i, (input, target) in enumerate(val_loader):
-        target = target.cuda(async=True)
-        input_var = torch.autograd.Variable(torch.cat(input,1).cuda(), volatile=True)
-        target_var = torch.autograd.Variable(target, volatile=True)
+        target = target.to(device)
+        input = torch.cat(input,1).to(device)
 
         # compute output
-        output = model(input_var)
-        flow2_EPE = args.div_flow*realEPE(output, target_var, sparse=args.sparse)
+        output = model(input)
+        flow2_EPE = args.div_flow*realEPE(output, target, sparse=args.sparse)
         # record EPE
-        flow2_EPEs.update(flow2_EPE.data[0], target.size(0))
+        flow2_EPEs.update(flow2_EPE.item(), target.size(0))
 
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
 
         if i < len(output_writers):  # log first output of first batches
             if epoch == 0:
-                output_writers[i].add_image('GroundTruth', flow2rgb(args.div_flow * target[0].cpu().numpy(), max_value=10), 0)
-                output_writers[i].add_image('Inputs', input[0][0].numpy().transpose(1, 2, 0) + np.array([0.411,0.432,0.45]), 0)
-                output_writers[i].add_image('Inputs', input[1][0].numpy().transpose(1, 2, 0) + np.array([0.411,0.432,0.45]), 1)
-                output_writers[i].add_image('FlowNet Outputs', flow2rgb(args.div_flow * output.data[0].cpu().numpy(), max_value=10), epoch)
+                mean_values = torch.tensor([0.411,0.432,0.45], dtype=input.dtype).view(3,1,1)
+                output_writers[i].add_image('GroundTruth', flow2rgb(args.div_flow * target[0], max_value=10), 0)
+                output_writers[i].add_image('Inputs', (input[0,:3].cpu() + mean_values).clamp(0,1), 0)
+                output_writers[i].add_image('Inputs', (input[0,3:].cpu() + mean_values).clamp(0,1), 1)
+                output_writers[i].add_image('FlowNet Outputs', flow2rgb(args.div_flow * output[0], max_value=10), epoch)
 
         if i % args.print_freq == 0:
             print('Test: [{0}/{1}]\t Time {2}\t EPE {3}'
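The new `mean_values` tensor adds back the per-channel mean that the data loader is assumed to have subtracted, so logged inputs are displayable. A sketch of the broadcasting trick (mean values copied from the diff, image size illustrative):

import torch

mean_values = torch.tensor([0.411, 0.432, 0.45]).view(3, 1, 1)
img = torch.rand(3, 32, 32) - mean_values     # stand-in normalized image
restored = (img + mean_values).clamp(0, 1)    # back into displayable [0, 1]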
@@ -337,17 +337,17 @@ def __repr__(self):
 
 
 def flow2rgb(flow_map, max_value):
-    global args
-    _, h, w = flow_map.shape
-    flow_map[:,(flow_map[0] == 0) & (flow_map[1] == 0)] = float('nan')
-    rgb_map = np.ones((h,w,3)).astype(np.float32)
+    flow_map_np = flow_map.detach().cpu().numpy()
+    _, h, w = flow_map_np.shape
+    flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan')
+    rgb_map = np.ones((3,h,w)).astype(np.float32)
     if max_value is not None:
-        normalized_flow_map = flow_map / max_value
+        normalized_flow_map = flow_map_np / max_value
     else:
-        normalized_flow_map = flow_map / (np.abs(flow_map).max())
-    rgb_map[:,:,0] += normalized_flow_map[0]
-    rgb_map[:,:,1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
-    rgb_map[:,:,2] += normalized_flow_map[1]
+        normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
+    rgb_map[0] += normalized_flow_map[0]
+    rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
+    rgb_map[2] += normalized_flow_map[1]
     return rgb_map.clip(0,1)
 
 
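flow2rgb now accepts a tensor directly and returns a (3, h, w) array, matching the CHW layout that tensorboardX's add_image expects by default. A hypothetical call:

import torch

flow = torch.randn(2, 64, 64)         # (u, v) flow map, illustrative values
rgb = flow2rgb(flow, max_value=10)    # float32 array, shape (3, 64, 64), in [0, 1]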

models/FlowNetS.py

Lines changed: 4 additions & 4 deletions
@@ -75,12 +75,12 @@ def __init__(self,batchNorm=True):
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
-                kaiming_normal(m.weight.data)
+                kaiming_normal_(m.weight, 0.1)
                 if m.bias is not None:
-                    m.bias.data.zero_()
+                    constant_(m.bias, 0)
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                constant_(m.weight, 1)
+                constant_(m.bias, 0)
 
     def forward(self, x):
         out_conv2 = self.conv2(self.conv1(x))
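These are the in-place initializers from torch.nn.init, renamed with a trailing underscore in 0.4; they operate on the Parameter directly, so going through `.data` is no longer needed. A minimal sketch with a standalone layer:

import torch.nn as nn
from torch.nn.init import kaiming_normal_, constant_

conv = nn.Conv2d(6, 64, 3)
kaiming_normal_(conv.weight, 0.1)   # 0.1 is the negative-slope argument `a`
constant_(conv.bias, 0)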

multiscaleloss.py

Lines changed: 13 additions & 5 deletions
@@ -1,5 +1,5 @@
 import torch
-import torch.nn as nn
+import torch.nn.functional as F
 
 
 def EPE(input_flow, target_flow, sparse=False, mean=True):
@@ -9,17 +9,25 @@ def EPE(input_flow, target_flow, sparse=False, mean=True):
         # invalid flow is defined with both flow coordinates to be exactly 0
         mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
 
-        EPE_map = EPE_map[~mask.data]
+        EPE_map = EPE_map[~mask]
     if mean:
         return EPE_map.mean()
     else:
         return EPE_map.sum()/batch_size
 
 
 def sparse_max_pool(input, size):
+    '''Downsample the input by considering 0 values as invalid.
+
+    Unfortunately, no generic interpolation mode can resize a sparse map correctly;
+    the strategy here is to use max pooling for positive values and "min pooling"
+    for negative values, and the two results are then summed.
+    This technique allows sparsity to be minimized, contrary to nearest interpolation,
+    which could potentially lose information for isolated data points.'''
+
     positive = (input > 0).float()
     negative = (input < 0).float()
-    output = nn.functional.adaptive_max_pool2d(input * positive, size) - nn.functional.adaptive_max_pool2d(-input * negative, size)
+    output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)
    return output
 
 
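A quick check of the behaviour the docstring describes, restating the function so the snippet runs on its own: a single negative value survives downsampling, where average pooling would dilute it among the surrounding zeros.

import torch
import torch.nn.functional as F

def sparse_max_pool(input, size):
    positive = (input > 0).float()
    negative = (input < 0).float()
    return F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)

x = torch.zeros(1, 1, 4, 4)
x[0, 0, 1, 1] = -3.0
print(sparse_max_pool(x, (2, 2)))   # the -3 survives in the 2x2 output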

@@ -31,7 +39,7 @@ def one_scale(output, target, sparse):
         if sparse:
             target_scaled = sparse_max_pool(target, (h, w))
         else:
-            target_scaled = nn.functional.adaptive_avg_pool2d(target, (h, w))
+            target_scaled = F.interpolate(target, (h, w), mode='area')
         return EPE(output, target_scaled, sparse, mean=False)
 
     if type(network_output) not in [tuple, list]:
@@ -48,5 +56,5 @@ def one_scale(output, target, sparse):
 
 def realEPE(output, target, sparse=False):
     b, _, h, w = target.size()
-    upsampled_output = nn.functional.upsample(output, size=(h,w), mode='bilinear')
+    upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False)
     return EPE(upsampled_output, target, sparse, mean=True)
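F.interpolate supersedes the deprecated F.upsample, and 0.4.x warns when `align_corners` is left unset for bilinear mode, hence the explicit `align_corners=False`. A minimal sketch with illustrative sizes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 16, 16)   # stand-in low-resolution flow prediction
y = F.interpolate(x, (64, 64), mode='bilinear', align_corners=False)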
