1+ import os
2+ import numpy as np
3+ import shutil
4+ import torch
5+
6+
7+ def save_checkpoint (state , is_best , save_path , filename = 'checkpoint.pth.tar' ):
8+ torch .save (state , os .path .join (save_path ,filename ))
9+ if is_best :
10+ shutil .copyfile (os .path .join (save_path ,filename ), os .path .join (save_path ,'model_best.pth.tar' ))
11+
12+
13+ class AverageMeter (object ):
14+ """Computes and stores the average and current value"""
15+
16+ def __init__ (self ):
17+ self .reset ()
18+
19+ def reset (self ):
20+ self .val = 0
21+ self .avg = 0
22+ self .sum = 0
23+ self .count = 0
24+
25+ def update (self , val , n = 1 ):
26+ self .val = val
27+ self .sum += val * n
28+ self .count += n
29+ self .avg = self .sum / self .count
30+
31+ def __repr__ (self ):
32+ return '{:.3f} ({:.3f})' .format (self .val , self .avg )
33+
34+
35+ def flow2rgb (flow_map , max_value ):
36+ flow_map_np = flow_map .detach ().cpu ().numpy ()
37+ _ , h , w = flow_map_np .shape
38+ flow_map_np [:,(flow_map_np [0 ] == 0 ) & (flow_map_np [1 ] == 0 )] = float ('nan' )
39+ rgb_map = np .ones ((3 ,h ,w )).astype (np .float32 )
40+ if max_value is not None :
41+ normalized_flow_map = flow_map_np / max_value
42+ else :
43+ normalized_flow_map = flow_map_np / (np .abs (flow_map_np ).max ())
44+ rgb_map [0 ] += normalized_flow_map [0 ]
45+ rgb_map [1 ] -= 0.5 * (normalized_flow_map [0 ] + normalized_flow_map [1 ])
46+ rgb_map [2 ] += normalized_flow_map [1 ]
47+ return rgb_map .clip (0 ,1 )
0 commit comments