Skip to content

Commit 02d58e0

Browse files
author
Clément Pinard
committed
update run_inference
* now uses imageio for img reading/writing * does not need to import the whole main script, just the util * more options, to allow for different values to output
1 parent 678ac41 commit 02d58e0

File tree

3 files changed

+48
-61
lines changed

3 files changed

+48
-61
lines changed

README.md

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ It has not been tested for multiple GPU, but it should work just as in original
88

99
The code provides a training example, using [the flying chair dataset](http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html) , with data augmentation. An implementation for [Scene Flow Datasets](http://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) may be added in the future.
1010

11-
Two neural network models are currently provided :
11+
Two neural network models are currently provided, along with their batch norm variation (experimental) :
1212

1313
- **FlowNetS**
1414
- **FlowNetSBN**
@@ -22,12 +22,12 @@ Thanks to [Kaixhin](https://github.com/Kaixhin) you can download a pretrained ve
2222
Directly feed the downloaded Network to the script, you don't need to uncompress it even if your desktop environment tells you so.
2323

2424
### Note on networks from caffe
25-
These networks expect a BGR input in range `[-0.5,0.5]` (compared to RGB in pytorch). However, BGR order is not very important.
25+
These networks expect a BGR input (compared to RGB in pytorch). However, BGR order is not very important.
2626

2727
## Prerequisite
2828

2929
```
30-
pytorch >= 0.4.1
30+
pytorch >= 1.0.1
3131
tensorboard-pytorch
3232
tensorboardX >= 1.4
3333
spatial-correlation-sampler>=0.0.8
@@ -88,6 +88,22 @@ Exact code for Optical Flow -> Color map can be found [here](main.py#L321)
8888
| <img src='images/input_2.gif' width=256> | <img src='images/pred_2.png' width=256> | <img src='images/GT_2.png' width=256> |
8989
| <img src='images/input_3.gif' width=256> | <img src='images/pred_3.png' width=256> | <img src='images/GT_3.png' width=256> |
9090

91+
## Running inference on a set of image pairs
92+
93+
If you need to run the network on your images, you can download a pretrained network [here](https://drive.google.com/open?id=0B5EC7HMbyk3CbjFPb0RuODI3NmM) and launch the inference script on your folder of image pairs.
94+
95+
Your folder needs to have all the images pairs in the same location, with the name pattern
96+
```
97+
{image_name}1.{ext}
98+
{image_name}2.{ext}
99+
```
100+
101+
```bash
102+
python3 run_inference.py /path/to/images/folder /path/to/pretrained
103+
```
104+
105+
As for the `main.py` script, a help menu is available for additional options.
106+
91107
## Note on transform functions
92108

93109
In order to have coherent transformations between inputs and target, we must define new transformations that take both input and target, as a new random variable is defined each time a random transformation is called.

main.py

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import os
3-
import shutil
43
import time
54

65
import torch
@@ -16,13 +15,12 @@
1615
from multiscaleloss import multiscaleEPE, realEPE
1716
import datetime
1817
from tensorboardX import SummaryWriter
19-
import numpy as np
18+
from util import flow2rgb, AverageMeter, save_checkpoint
2019

2120
model_names = sorted(name for name in models.__dict__
2221
if name.islower() and not name.startswith("__"))
2322
dataset_names = sorted(name for name in datasets.__all__)
2423

25-
2624
parser = argparse.ArgumentParser(description='PyTorch FlowNet Training on several datasets',
2725
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
2826
parser.add_argument('data', metavar='DIR',
@@ -86,7 +84,7 @@
8684

8785

8886
def main():
89-
global args, best_EPE, save_path
87+
global args, best_EPE
9088
args = parser.parse_args()
9189
save_path = '{},{},{}epochs{},b{},lr{}'.format(
9290
args.arch,
@@ -209,7 +207,7 @@ def main():
209207
'state_dict': model.module.state_dict(),
210208
'best_EPE': best_EPE,
211209
'div_flow': args.div_flow
212-
}, is_best)
210+
}, is_best, save_path)
213211

214212

215213
def train(train_loader, model, optimizer, epoch, train_writer):
@@ -308,48 +306,5 @@ def validate(val_loader, model, epoch, output_writers):
308306
return flow2_EPEs.avg
309307

310308

311-
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
312-
torch.save(state, os.path.join(save_path,filename))
313-
if is_best:
314-
shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth.tar'))
315-
316-
317-
class AverageMeter(object):
318-
"""Computes and stores the average and current value"""
319-
320-
def __init__(self):
321-
self.reset()
322-
323-
def reset(self):
324-
self.val = 0
325-
self.avg = 0
326-
self.sum = 0
327-
self.count = 0
328-
329-
def update(self, val, n=1):
330-
self.val = val
331-
self.sum += val * n
332-
self.count += n
333-
self.avg = self.sum / self.count
334-
335-
def __repr__(self):
336-
return '{:.3f} ({:.3f})'.format(self.val, self.avg)
337-
338-
339-
def flow2rgb(flow_map, max_value):
340-
flow_map_np = flow_map.detach().cpu().numpy()
341-
_, h, w = flow_map_np.shape
342-
flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan')
343-
rgb_map = np.ones((3,h,w)).astype(np.float32)
344-
if max_value is not None:
345-
normalized_flow_map = flow_map_np / max_value
346-
else:
347-
normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
348-
rgb_map[0] += normalized_flow_map[0]
349-
rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
350-
rgb_map[2] += normalized_flow_map[1]
351-
return rgb_map.clip(0,1)
352-
353-
354309
if __name__ == '__main__':
355310
main()

run_inference.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import torch.nn.functional as F
77
import models
88
from tqdm import tqdm
9+
910
import torchvision.transforms as transforms
1011
import flow_transforms
11-
from scipy.ndimage import imread
12-
from scipy.misc import imsave
12+
from imageio import imread, imwrite
1313
import numpy as np
14-
from main import flow2rgb
14+
from util import flow2rgb
1515

1616
model_names = sorted(name for name in models.__dict__
1717
if name.islower() and not name.startswith("__"))
@@ -22,11 +22,15 @@
2222
parser.add_argument('data', metavar='DIR',
2323
help='path to images folder, image names must match \'[name]0.[ext]\' and \'[name]1.[ext]\'')
2424
parser.add_argument('pretrained', metavar='PTH', help='path to pre-trained model')
25-
parser.add_argument('--output', metavar='DIR', default=None,
25+
parser.add_argument('--output', '-o', metavar='DIR', default=None,
2626
help='path to output folder. If not set, will be created in data folder')
27+
parser.add_argument('--output-value', '-v', metavar='VAL', choices=['raw', 'vis', 'both'], default='both',
28+
help='which value to output, between raw input (as a npy file) and color vizualisation (as an image file).'
29+
' If not set, will output both')
2730
parser.add_argument('--div-flow', default=20, type=float,
2831
help='value by which flow will be divided. overwritten if stored in pretrained file')
29-
parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
32+
parser.add_argument("--img-exts", metavar='EXT', default=['png', 'jpg', 'bmp', 'ppm'], nargs='*', type=str,
33+
help="images extensions to glob")
3034
parser.add_argument('--max_flow', default=None, type=float,
3135
help='max flow value. Flow map color is saturated above this value. If not set, will use flow map\'s max value')
3236
parser.add_argument('--upsampling', '-u', choices=['nearest', 'bilinear'], default=None, help='if not set, will output FlowNet raw input,'
@@ -40,6 +44,14 @@
4044
def main():
4145
global args, save_path
4246
args = parser.parse_args()
47+
48+
if args.output_value == 'both':
49+
output_string = "raw output and RGB visualization"
50+
elif args.output_value == 'raw':
51+
output_string = "raw output"
52+
elif args.output_value == 'vis':
53+
output_string = "RGB visualization"
54+
print("=> will save " + output_string)
4355
data_dir = Path(args.data)
4456
print("=> fetching img pairs in '{}'".format(args.data))
4557
if args.output is None:
@@ -58,9 +70,9 @@ def main():
5870

5971
img_pairs = []
6072
for ext in args.img_exts:
61-
test_files = data_dir.files('*0.{}'.format(ext))
73+
test_files = data_dir.files('*1.{}'.format(ext))
6274
for file in test_files:
63-
img_pair = file.parent / (file.namebase[:-1] + '1.{}'.format(ext))
75+
img_pair = file.parent / (file.namebase[:-1] + '2.{}'.format(ext))
6476
if img_pair.isfile():
6577
img_pairs.append([file, img_pair])
6678

@@ -92,9 +104,13 @@ def main():
92104
if args.upsampling is not None:
93105
output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False)
94106
for suffix, flow_output in zip(['flow', 'inv_flow'], output):
95-
rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
96-
to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)
97-
imsave(save_path/'{}{}.png'.format(img1_file.namebase[:-1], suffix), to_save)
107+
filename = save_path/'{}{}'.format(img1_file.namebase[:-1], suffix)
108+
if args.output_value in['vis', 'both']:
109+
rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
110+
to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)
111+
imwrite(filename + '.png', to_save)
112+
if args.output_value in ['raw', 'both']:
113+
np.save(filename + 'npy', flow_output.cpu().numpy())
98114

99115

100116
if __name__ == '__main__':

0 commit comments

Comments
 (0)