1+ from __future__ import annotations
12import sys
2- from math import floor
33from pathlib import Path
44
55from PIL import Image
1313
1414
1515class RIFE :
16- def __init__ (self , gpuid : int = - 1 , model : str = "rife-HD" , tta_mode : bool = False , uhd_mode : bool = False , num_threads : int = 1 ):
16+ def __init__ (self ,
17+ gpuid : int = - 1 ,
18+ model : str = "rife-HD" ,
19+ scale : int = 2 ,
20+ tta_mode : bool = False ,
21+ uhd_mode : bool = False ,
22+ num_threads : int = 1 ):
1723 rife_v2 = "rife-v2" in model
1824 self .model = model
25+
26+ if (scale & (scale - 1 )) == 0 :
27+ self .scale = scale
28+ else :
29+ raise ValueError ("scale should be powers of 2" )
30+
1931 self ._raw_rife = raw .RIFEWrapper (gpuid , tta_mode , uhd_mode , num_threads , rife_v2 )
2032 self .load ()
2133
@@ -42,7 +54,24 @@ def load(self, model_dir: str = ""):
4254 else :
4355 raise FileNotFoundError (f"{ model_dir } not found" )
4456
45- def process (self , im0 : Image , im1 : Image ) -> Image :
57+ def process (self , im0 : Image , im1 : Image ) -> list [Image ]:
58+ """
59+ interpolate frames between im0 and im1
60+ :param im0: First frame
61+ :param im1: Second frame
62+ :return: the interpolation frames between im0 and im1
63+ """
64+ def _proc (im0 : Image , im1 : Image , level ) -> list [Image ]:
65+ if level == 1 :
66+ return []
67+ else :
68+ im = self ._process (im0 , im1 )
69+ level /= 2
70+ return _proc (im0 , im , level ) + [im ] + _proc (im , im1 , level )
71+
72+ return _proc (im0 , im1 , self .scale )
73+
74+ def _process (self , im0 : Image , im1 : Image ) -> Image :
4675 in_bytes0 , in_bytes1 = bytearray (im0 .tobytes ()), bytearray (im1 .tobytes ())
4776 channels = int (len (in_bytes0 ) / (im0 .width * im0 .height ))
4877 out_bytes = bytearray (len (in_bytes0 ))
@@ -62,6 +91,7 @@ def process(self, im0: Image, im1: Image) -> Image:
6291 t = time ()
6392 im0 , im1 = Image .open ("../images/0.png" ), Image .open ("../images/1.png" )
6493 rife = RIFE (0 )
65- im = rife .process (im0 , im1 )
66- im .save ("../images/out_wrapper.png" )
94+ ims = rife .process (im0 , im1 )
95+ for i , im in enumerate (ims ):
96+ im .save (f"../images/out_{ i } .png" )
6797 print (f"Elapsed time: { time () - t } s" )
0 commit comments