import argparse
import os
import sys
import warnings

# ---------------------------------------------------------------------------
# Command-line interface.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--basedir', type=str, required=True)
parser.add_argument('--exp', default=2, type=int)
parser.add_argument('--output', default='frames_check', type=str)
parser.add_argument('--gpu', default=0, type=int)

args = parser.parse_args()
# Fixed (formerly CLI-exposed) interpolation settings:
args.ratio = 0.0        # inference ratio between two images with 0 - 1 range
args.rthreshold = 0.02  # returns image when actual ratio falls in given range threshold
args.rmaxcycles = 8     # limit max number of bisectional cycles

# Must be set BEFORE torch initialises CUDA, so the GPU selection takes effect;
# this is why torch is imported only after this point.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

import cv2  # noqa: E402  (kept: imported by the original file)
import numpy as np  # noqa: E402
import tifffile as tif  # noqa: E402
import torch  # noqa: E402
from skimage.color import rgb2gray  # noqa: E402
from torch.nn import functional as F  # noqa: E402
from tqdm import tqdm  # noqa: E402

from model.RIFE_HDv2 import Model  # noqa: E402

sys.path.append('../../..')
from base.common import load_extracted_prj  # noqa: E402
warnings.filterwarnings("ignore")

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the two cudnn lines below were indented in the original, which
# implies an enclosing guard not visible in this chunk; restored here as the
# conventional availability check — confirm against the upstream script.
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    # benchmark autotunes conv algorithms; worthwhile since frame size is fixed
    torch.backends.cudnn.benchmark = True
# Load the pretrained RIFE model from the train_log directory next to this
# script and switch it to inference mode on the selected device.
model = Model()
train_log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')
model.load_model(train_log_dir, -1)
model.eval()
model.device()
# Load in all the projection images and scale them to [0, 255].
total_prjs = load_extracted_prj(args.basedir)
save_location = f"{args.basedir}rife/{args.output}/"
prj_max = total_prjs.max()  # one global maximum so every frame shares the same scale
total_prjs = total_prjs / prj_max * 255.0
# Frames are written as img_<index>.tif; pad the index so filenames sort
# lexicographically.  Interpolation multiplies the frame count by 2**exp.
zfill_val = len(str(len(total_prjs) * 2 ** args.exp))
current_frame = 0
# makedirs (not mkdir) because save_location is nested ('<basedir>/rife/<output>/')
# and the intermediate 'rife' directory may not exist yet; hoisted out of the
# loop since it is iteration-invariant.
os.makedirs(save_location, exist_ok=True)

for iteration in tqdm(range(len(total_prjs) - 1), desc='Interpolation'):
    # Replicate the single-channel projection into 3 channels — the network
    # consumes RGB input.
    img0 = np.dstack((total_prjs[iteration],) * 3).astype(np.float32)
    img1 = np.dstack((total_prjs[iteration + 1],) * 3).astype(np.float32)

    # HWC -> NCHW, rescale [0, 255] -> [0, 1], add batch dimension.
    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)

    # Pad spatial dims up to a multiple of 32; the padding is cropped off
    # again (:h, :w) before saving.
    n, c, h, w = img0.shape
    ph = ((h - 1) // 32 + 1) * 32
    pw = ((w - 1) // 32 + 1) * 32
    padding = (0, pw - w, 0, ph - h)
    img0 = F.pad(img0, padding)
    img1 = F.pad(img1, padding)

    # Inference only: without no_grad each model.inference call would keep its
    # autograd graph alive and memory would grow across iterations.
    with torch.no_grad():
        if args.ratio:
            # Dead in practice: args.ratio is hard-coded to 0.0 at the top of
            # the script, so this bisection branch never runs.  Kept verbatim.
            img_list = [img0]
            img0_ratio = 0.0
            img1_ratio = 1.0
            if args.ratio <= img0_ratio + args.rthreshold / 2:
                middle = img0
            elif args.ratio >= img1_ratio - args.rthreshold / 2:
                middle = img1
            else:
                tmp_img0 = img0
                tmp_img1 = img1
                # Bisect between the two frames until the midpoint ratio falls
                # within rthreshold of the requested ratio (max rmaxcycles).
                for inference_cycle in range(args.rmaxcycles):
                    middle = model.inference(tmp_img0, tmp_img1)
                    middle_ratio = (img0_ratio + img1_ratio) / 2
                    if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
                        break
                    if args.ratio > middle_ratio:
                        tmp_img0 = middle
                        img0_ratio = middle_ratio
                    else:
                        tmp_img1 = middle
                        img1_ratio = middle_ratio
            img_list.append(middle)
            img_list.append(img1)
        else:
            # Recursive doubling: each pass inserts a midpoint between every
            # adjacent pair, so exp passes yield 2**exp + 1 frames.
            img_list = [img0, img1]
            for _ in range(args.exp):
                tmp = []
                for j in range(len(img_list) - 1):
                    tmp.append(img_list[j])
                    tmp.append(model.inference(img_list[j], img_list[j + 1]))
                tmp.append(img1)
                img_list = tmp

    # NOTE(review): both endpoints are saved on every iteration, so each shared
    # boundary frame is written twice (iteration i's img1 == iteration i+1's
    # img0) — confirm whether that duplication is intended downstream.
    for frame in img_list:
        im2save = (frame[0] * 255).cpu().numpy().transpose(1, 2, 0)[:h, :w]
        im2save = rgb2gray(im2save)  # collapse the replicated channels back to grayscale
        tif.imsave(f'{save_location}img_{str(current_frame).zfill(zfill_val)}.tif', im2save)
        current_frame += 1