Skip to content

Commit e271d7a

Browse files
author
Clément Pinard
committed
use full res for sparse
1 parent 8e608b6 commit e271d7a

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ def train(train_loader, model, optimizer, epoch, train_writer):
234234

235235
# compute output
236236
output = model(input_var)
237+
if args.sparse:
238+
# Since Target pooling is not very precise when sparse,
239+
# take the highest resolution prediction and upsample it instead of downsampling target
240+
h, w = target_var.size()[-2:]
241+
output = [torch.nn.functional.upsample(output[0], (h,w)), *output[1:]]
237242

238243
loss = multiscaleEPE(output, target_var, weights=args.multiscale_weights, sparse=args.sparse)
239244
flow2_EPE = args.div_flow * realEPE(output[0], target_var, sparse=args.sparse)

0 commit comments

Comments
 (0)