Skip to content

Commit 10c3aa6

Browse files
committed
Add support for interpolating more than one frame
Signed-off-by: ArchieMeng <archiemeng@protonmail.com>
1 parent c8734e0 commit 10c3aa6

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

src/rife_ncnn_vulkan.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
from __future__ import annotations
12
import sys
2-
from math import floor
33
from pathlib import Path
44

55
from PIL import Image
@@ -13,9 +13,21 @@
1313

1414

1515
class RIFE:
16-
def __init__(self, gpuid: int = -1, model: str = "rife-HD", tta_mode: bool = False, uhd_mode: bool = False, num_threads: int = 1):
16+
def __init__(self,
17+
gpuid: int = -1,
18+
model: str = "rife-HD",
19+
scale: int = 2,
20+
tta_mode: bool = False,
21+
uhd_mode: bool = False,
22+
num_threads: int = 1):
1723
rife_v2 = "rife-v2" in model
1824
self.model = model
25+
26+
if (scale & (scale -1)) == 0:
27+
self.scale = scale
28+
else:
29+
raise ValueError("scale should be powers of 2")
30+
1931
self._raw_rife = raw.RIFEWrapper(gpuid, tta_mode, uhd_mode, num_threads, rife_v2)
2032
self.load()
2133

@@ -42,7 +54,24 @@ def load(self, model_dir: str = ""):
4254
else:
4355
raise FileNotFoundError(f"{model_dir} not found")
4456

45-
def process(self, im0: Image, im1: Image) -> Image:
57+
def process(self, im0: Image, im1: Image) -> list[Image]:
58+
"""
59+
interpolate frames between im0 and im1
60+
:param im0: First frame
61+
:param im1: Second frame
62+
:return: the interpolation frames between im0 and im1
63+
"""
64+
def _proc(im0: Image, im1: Image, level) -> list[Image]:
65+
if level == 1:
66+
return []
67+
else:
68+
im = self._process(im0, im1)
69+
level /= 2
70+
return _proc(im0, im, level) + [im] + _proc(im, im1, level)
71+
72+
return _proc(im0, im1, self.scale)
73+
74+
def _process(self, im0: Image, im1: Image) -> Image:
4675
in_bytes0, in_bytes1 = bytearray(im0.tobytes()), bytearray(im1.tobytes())
4776
channels = int(len(in_bytes0) / (im0.width * im0.height))
4877
out_bytes = bytearray(len(in_bytes0))
@@ -62,6 +91,7 @@ def process(self, im0: Image, im1: Image) -> Image:
6291
t = time()
6392
im0, im1 = Image.open("../images/0.png"), Image.open("../images/1.png")
6493
rife = RIFE(0)
65-
im = rife.process(im0, im1)
66-
im.save("../images/out_wrapper.png")
94+
ims = rife.process(im0, im1)
95+
for i, im in enumerate(ims):
96+
im.save(f"../images/out_{i}.png")
6797
print(f"Elapsed time: {time() - t}s")

0 commit comments

Comments
 (0)