Skip to content

Commit c37a8db

Browse files
author
Clément Pinard
committed
add util file
1 parent 5d8814f commit c37a8db

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

util.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
import numpy as np
3+
import shutil
4+
import torch
5+
6+
7+
def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
8+
torch.save(state, os.path.join(save_path,filename))
9+
if is_best:
10+
shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth.tar'))
11+
12+
13+
class AverageMeter(object):
14+
"""Computes and stores the average and current value"""
15+
16+
def __init__(self):
17+
self.reset()
18+
19+
def reset(self):
20+
self.val = 0
21+
self.avg = 0
22+
self.sum = 0
23+
self.count = 0
24+
25+
def update(self, val, n=1):
26+
self.val = val
27+
self.sum += val * n
28+
self.count += n
29+
self.avg = self.sum / self.count
30+
31+
def __repr__(self):
32+
return '{:.3f} ({:.3f})'.format(self.val, self.avg)
33+
34+
35+
def flow2rgb(flow_map, max_value):
36+
flow_map_np = flow_map.detach().cpu().numpy()
37+
_, h, w = flow_map_np.shape
38+
flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan')
39+
rgb_map = np.ones((3,h,w)).astype(np.float32)
40+
if max_value is not None:
41+
normalized_flow_map = flow_map_np / max_value
42+
else:
43+
normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
44+
rgb_map[0] += normalized_flow_map[0]
45+
rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
46+
rgb_map[2] += normalized_flow_map[1]
47+
return rgb_map.clip(0,1)

0 commit comments

Comments
 (0)