Skip to content

Commit 0838052

Browse files
author
Cabana-HPC1
committed
FEAT: Adding float32 support for RIFE
1 parent 2195932 commit 0838052

File tree

3 files changed

+134
-72
lines changed

3 files changed

+134
-72
lines changed

tomosuitepy/base/reconstruct.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,10 @@ def prepare_rife(basedir, start_row, end_row, rife_types, verbose):
504504

505505
for file in tqdm(new_files, desc='Loading Data'):
506506
original = cv2.imread(file, -1)
507-
fixed_original = rgb2gray(original)
507+
if len(np.shape(original)) == 3:
508+
fixed_original = rgb2gray(original)
509+
else:
510+
fixed_original = original
508511
if rife_types[2] == True:
509512
fixed_original = np.log(fixed_original)
510513
#fixed_original *= 255.0
@@ -515,7 +518,7 @@ def prepare_rife(basedir, start_row, end_row, rife_types, verbose):
515518
shape = prj_data.shape[0]
516519

517520
_theta = np.load(f"{basedir}extracted/theta/theta.npy")
518-
theta = np.linspace(_theta[0], theta[-1], prj_data.shape[0])
521+
theta = np.linspace(_theta[0], _theta[-1], prj_data.shape[0])
519522

520523
if verbose:
521524
print(f"The shape of this data is: {prj_data.shape}")

tomosuitepy/easy_networks/rife/data_prep.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def deal_with_sparse_angle(prj_data, theta,
6262

6363
return prj_data, theta
6464

65+
6566
def view_prj_contrast(basedir, cutoff=None, above_or_below='below',
6667
analysis_func=np.sum, plot=True):
6768

@@ -110,7 +111,6 @@ def view_prj_contrast(basedir, cutoff=None, above_or_below='below',
110111
return np.asarray(analysis_idx), prj_data[analysis_idx], theta[analysis_idx]
111112

112113

113-
114114
def create_prj_mp4(basedir, video_type='input', types='base', force_positive=False,
115115
sparse_angle_removal=0, fps=30, torf=False, apply_exp=False, prj_idx=None):
116116
"""
@@ -213,6 +213,7 @@ def create_prj_mp4(basedir, video_type='input', types='base', force_positive=Fal
213213
return prj_data, np.asarray(out_data)
214214

215215

216+
216217
def rife_predict(basedir, location_of_rife=rife_path, exp=2, scale=1.0,
217218
gpu='0', video_input_type='input',
218219
video_output_type='predicted',
@@ -348,4 +349,42 @@ def create_prj_mp4_old(basedir, output_file, types='base', sparse_angle_removal=
348349
im *= 255.0
349350
out.write((im).astype(np.uint8))
350351

351-
out.release()
352+
out.release()
353+
354+
355+
def full_res_rife(basedir, location_of_rife=rife_path, exp=2,
356+
gpu='0', output_folder='frames', python_location=''):
357+
"""
358+
    Use the neural network called RIFE to increase the number of projections via frame interpolation.
359+
360+
Parameters
361+
----------
362+
basedir : str
363+
Path to the project.
364+
365+
location_of_rife : str
366+
Path to the github repo of RIFE with / at the end.
367+
368+
exp : int
369+
        The number of frames is increased by a factor of 2**exp.
370+
371+
gpu : str
372+
The string index of the gpu to use.
373+
374+
output_folder : str
375+
The name of the output folder to be created at {basedir}rife/{output_folder}/
376+
377+
Returns
378+
-------
379+
command
380+
A command to be used in a terminal with the RIFE conda env variables installed.
381+
"""
382+
383+
pre = f'cd {location_of_rife} &&'
384+
first = f'{python_location}python inference_img.py'
385+
second = f'--exp={exp}'
386+
third = f'--basedir={basedir}'
387+
fourth_inter = f'--gpu={gpu}'
388+
fifth = f"--output={output_folder}"
389+
390+
return f"{pre} {first} {second} {third} {fourth_inter} {fifth}"
Lines changed: 88 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
1-
import os
2-
import cv2
1+
import os, sys, warnings, cv2, argparse
2+
3+
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
4+
parser.add_argument('--basedir', type=str, required=True)
5+
parser.add_argument('--exp', default=2, type=int)
6+
parser.add_argument('--output', default='frames_check', type=str)
7+
parser.add_argument('--gpu', default=0, type=int)
8+
9+
args = parser.parse_args()
10+
args.ratio = float(0.0) # inference ratio between two images with 0 - 1 range
11+
args.rthreshold = float(0.02) # returns image when actual ratio falls in given range threshold
12+
args.rmaxcycles = int(8) # limit max number of bisectional cycles
13+
14+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
15+
316
import torch
4-
import argparse
517
from torch.nn import functional as F
618
from model.RIFE_HDv2 import Model
7-
import warnings
19+
import tifffile as tif
20+
import numpy as np
21+
from tqdm import tqdm
22+
from skimage.color import rgb2gray
23+
24+
sys.path.append('../../..')
25+
from base.common import load_extracted_prj
26+
827
warnings.filterwarnings("ignore")
928

1029
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -13,78 +32,79 @@
1332
torch.backends.cudnn.enabled = True
1433
torch.backends.cudnn.benchmark = True
1534

16-
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
17-
parser.add_argument('--img', dest='img', nargs=2, required=True)
18-
parser.add_argument('--exp', default=4, type=int)
19-
parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
20-
parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
21-
parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
22-
args = parser.parse_args()
23-
2435
model = Model()
2536
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
2637
model.eval()
2738
model.device()
2839

29-
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
30-
img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
31-
img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
32-
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
33-
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
40+
# Load in all the images
41+
total_prjs = load_extracted_prj(args.basedir)
42+
save_location = f"{args.basedir}rife/{args.output}/"
43+
prj_max = total_prjs.max()
44+
total_prjs = total_prjs / prj_max
45+
total_prjs = total_prjs * 255.0
3446

35-
else:
36-
img0 = cv2.imread(args.img[0])
37-
img1 = cv2.imread(args.img[1])
38-
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
39-
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
47+
# Iterate over the images
48+
zfill_val = len(str(len(total_prjs) * 2**args.exp))
49+
current_frame = 0
4050

41-
n, c, h, w = img0.shape
42-
ph = ((h - 1) // 32 + 1) * 32
43-
pw = ((w - 1) // 32 + 1) * 32
44-
padding = (0, pw - w, 0, ph - h)
45-
img0 = F.pad(img0, padding)
46-
img1 = F.pad(img1, padding)
51+
for iteration in tqdm(range(0, len(total_prjs)-1), desc='Interpolation'):
52+
# Tripple the image arrays
53+
img0 = np.dstack((total_prjs[iteration], total_prjs[iteration], total_prjs[iteration]))
54+
img1 = np.dstack((total_prjs[iteration + 1], total_prjs[iteration + 1], total_prjs[iteration + 1]))
4755

56+
img0 = img0.astype(np.float32)
57+
img1 = img1.astype(np.float32)
4858

49-
if args.ratio:
50-
img_list = [img0]
51-
img0_ratio = 0.0
52-
img1_ratio = 1.0
53-
if args.ratio <= img0_ratio + args.rthreshold / 2:
54-
middle = img0
55-
elif args.ratio >= img1_ratio - args.rthreshold / 2:
56-
middle = img1
57-
else:
58-
tmp_img0 = img0
59-
tmp_img1 = img1
60-
for inference_cycle in range(args.rmaxcycles):
61-
middle = model.inference(tmp_img0, tmp_img1)
62-
middle_ratio = ( img0_ratio + img1_ratio ) / 2
63-
if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
64-
break
65-
if args.ratio > middle_ratio:
66-
tmp_img0 = middle
67-
img0_ratio = middle_ratio
68-
else:
69-
tmp_img1 = middle
70-
img1_ratio = middle_ratio
71-
img_list.append(middle)
72-
img_list.append(img1)
73-
else:
74-
img_list = [img0, img1]
75-
for i in range(args.exp):
76-
tmp = []
77-
for j in range(len(img_list) - 1):
78-
mid = model.inference(img_list[j], img_list[j + 1])
79-
tmp.append(img_list[j])
80-
tmp.append(mid)
81-
tmp.append(img1)
82-
img_list = tmp
59+
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
60+
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
61+
62+
n, c, h, w = img0.shape
63+
ph = ((h - 1) // 32 + 1) * 32
64+
pw = ((w - 1) // 32 + 1) * 32
65+
padding = (0, pw - w, 0, ph - h)
66+
img0 = F.pad(img0, padding)
67+
img1 = F.pad(img1, padding)
8368

84-
if not os.path.exists('output'):
85-
os.mkdir('output')
86-
for i in range(len(img_list)):
87-
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
88-
cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
69+
if args.ratio:
70+
img_list = [img0]
71+
img0_ratio = 0.0
72+
img1_ratio = 1.0
73+
if args.ratio <= img0_ratio + args.rthreshold / 2:
74+
middle = img0
75+
elif args.ratio >= img1_ratio - args.rthreshold / 2:
76+
middle = img1
77+
else:
78+
tmp_img0 = img0
79+
tmp_img1 = img1
80+
for inference_cycle in range(args.rmaxcycles):
81+
middle = model.inference(tmp_img0, tmp_img1)
82+
middle_ratio = ( img0_ratio + img1_ratio ) / 2
83+
if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
84+
break
85+
if args.ratio > middle_ratio:
86+
tmp_img0 = middle
87+
img0_ratio = middle_ratio
88+
else:
89+
tmp_img1 = middle
90+
img1_ratio = middle_ratio
91+
img_list.append(middle)
92+
img_list.append(img1)
8993
else:
90-
cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
94+
img_list = [img0, img1]
95+
for i in range(args.exp):
96+
tmp = []
97+
for j in range(len(img_list) - 1):
98+
mid = model.inference(img_list[j], img_list[j + 1])
99+
tmp.append(img_list[j])
100+
tmp.append(mid)
101+
tmp.append(img1)
102+
img_list = tmp
103+
104+
if not os.path.exists(save_location):
105+
os.mkdir(save_location)
106+
for i in range(len(img_list)):
107+
im2save = (img_list[i][0] * 255).cpu().numpy().transpose(1, 2, 0)[:h, :w]
108+
im2save = rgb2gray(im2save)
109+
tif.imsave(f'{save_location}/img_{str(current_frame).zfill(zfill_val)}.tif', im2save)
110+
current_frame += 1

0 commit comments

Comments
 (0)