33
44import torch
55import torch .backends .cudnn as cudnn
6+ import torch .nn .functional as F
67import models
78from tqdm import tqdm
89import torchvision .transforms as transforms
3233 'which is 4 times downsampled. If set, will output full resolution flow map, with selected upsampling' )
3334parser .add_argument ('--bidirectional' , action = 'store_true' , help = 'if set, will output invert flow (from 1 to 0) along with regular flow' )
3435
36+ device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
37+
38+
3539@torch .no_grad ()
3640def main ():
3741 global args , save_path
3842 args = parser .parse_args ()
3943 data_dir = Path (args .data )
4044 print ("=> fetching img pairs in '{}'" .format (args .data ))
41- save_path = data_dir / 'flow'
45+ if args .output is None :
46+ save_path = data_dir / 'flow'
47+ else :
48+ save_path = Path (args .output )
4249 print ('=> will save everything to {}' .format (save_path ))
4350 save_path .makedirs_p ()
4451
@@ -61,7 +68,7 @@ def main():
6168 # create model
6269 network_data = torch .load (args .pretrained )
6370 print ("=> using pre-trained model '{}'" .format (network_data ['arch' ]))
64- model = models .__dict__ [network_data ['arch' ]](network_data ).cuda ( )
71+ model = models .__dict__ [network_data ['arch' ]](network_data ).to ( device )
6572 model .eval ()
6673 cudnn .benchmark = True
6774
@@ -72,19 +79,20 @@ def main():
7279
7380 img1 = input_transform (imread (img1_file ))
7481 img2 = input_transform (imread (img2_file ))
75- input_var = torch .tensor ( torch . cat ([img1 , img2 ]). cuda () ).unsqueeze (0 )
82+ input_var = torch .cat ([img1 , img2 ]).unsqueeze (0 )
7683
7784 if args .bidirectional :
7885 # feed inverted pair along with normal pair
79- inverted_input_var = torch .tensor ( torch . cat ([img2 , img1 ], 0 ). cuda () ).unsqueeze (0 )
86+ inverted_input_var = torch .cat ([img2 , img1 ]).unsqueeze (0 )
8087 input_var = torch .cat ([input_var , inverted_input_var ])
8188
89+ input_var = input_var .to (device )
8290 # compute output
8391 output = model (input_var )
8492 if args .upsampling is not None :
85- output = torch . nn . functional . upsample (output , size = img1 .size ()[- 2 :], mode = args .upsampling )
86- for suffix , flow_output in zip (['flow' , 'inv_flow' ], output . data . cpu () ):
87- rgb_flow = flow2rgb (args .div_flow * flow_output . numpy () , max_value = args .max_flow )
93+ output = F . interpolate (output , size = img1 .size ()[- 2 :], mode = args .upsampling , align_corners = False )
94+ for suffix , flow_output in zip (['flow' , 'inv_flow' ], output ):
95+ rgb_flow = flow2rgb (args .div_flow * flow_output , max_value = args .max_flow )
8896 to_save = (rgb_flow * 255 ).astype (np .uint8 ).transpose (1 ,2 ,0 )
8997 imsave (save_path / '{}{}.png' .format (img1_file .namebase [:- 1 ], suffix ), to_save )
9098
0 commit comments