# Inference.py
# (Recovered from a web-page scrape; the original page chrome and the
#  line-number gutter were extraction artifacts, not file content.)
# Source listing: 69 lines (47 loc), 2.09 KB.
import torch
import os
import sys
from models import utils
from models import model as model_hub
from torchvision.transforms import transforms
from models.model_cud import CUD_Loss
class Inferencer:
    """Single-image inference driver for a trained CUD_NET model.

    Builds the model on the selected device, loads a checkpoint, and turns an
    input image stack (rgb, deu, diff) into a clamped, m-inverted PIL image.
    """

    def __init__(self, args):
        # Keep the project's local model package importable (original convention).
        sys.path.insert(0, './model')
        self.args = args
        use_cuda = self.args.cuda and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.model = self.__init_model()
        # Fix: map_location lets a checkpoint saved on GPU load on a CPU-only
        # host; without it torch.load raises a RuntimeError there, defeating
        # the CPU fallback chosen just above.
        checkpoint = torch.load(args.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.criterion = CUD_Loss(ssim_window_size=5)
        self.trans = transforms.Compose([transforms.ToTensor()])
        self.model.eval()

    def start_inference(self, src_path, src_path_y=None):
        """Run the model on one image and return the processed output.

        Parameters
        ----------
        src_path : str
            Path to the input image; ``utils.load_image_deu_stack`` derives
            the (rgb, deu, diff) channel stack from it.
        src_path_y : str, optional
            Path to a ground-truth image. Only consumed by the histogram
            figure when ``args.save_figures`` is set.

        Returns
        -------
        PIL.Image.Image
            The clamped network output after ``utils.m_invert``.
        """
        _, fn = os.path.split(src_path)
        fn, _ = os.path.splitext(fn)
        input_rgb, input_rgb_deu, input_rgb_diff = utils.load_image_deu_stack(src_path)
        input_rgb = self.trans(input_rgb)
        input_rgb_deu = self.trans(input_rgb_deu)
        input_rgb_diff = self.trans(input_rgb_diff)
        input_src = torch.cat((input_rgb, input_rgb_deu, input_rgb_diff))
        # Fix: target_src was previously unbound when src_path_y is None but
        # save_figures is set, raising UnboundLocalError at the histogram call.
        # NOTE(review): save_histogram is assumed to accept target=None — confirm.
        target_src = None
        if src_path_y is not None:
            target_rgb = utils.load_image(src_path_y, self.args.grey_scale)
            target_src = self.trans(target_rgb).unsqueeze(0)
        input_src = input_src.unsqueeze(0)
        # Inference needs no autograd graph; no_grad saves memory and time.
        with torch.no_grad():
            output = self.model(input_src, fn=fn)
        output = torch.clamp(output, 0.0, 1.0)
        output = output.squeeze(0)
        if self.args.save_figures:
            utils.save_histogram(input_src, output, target_src, channel='RGB', data_name=fn)
        output = utils.tensor_to_pil(output)
        output = utils.m_invert(output)
        return output

    def __init_model(self):
        """Construct CUD_NET on the chosen device, wrapped in DataParallel."""
        model = torch.nn.DataParallel(
            model_hub.CUD_NET(num_points=self.args.points,
                              save_figures=self.args.save_figures,
                              clip_threshold=self.args.clip_threshold).to(self.device))
        return model