-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_REBlur.py
More file actions
120 lines (107 loc) · 3.69 KB
/
test_REBlur.py
File metadata and controls
120 lines (107 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from network import RE_Net
from utils import *
import torch.nn as nn
from dataset_h5 import REBlur
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import argparse
import yaml
# Seed the RNGs once at import time. set_random_seed is presumably provided by
# `from utils import *` — confirm.
# NOTE(review): the seed is hard-coded to 1; the --seed option parsed in
# get_parser() is never applied anywhere in this file.
set_random_seed(1)
# create dataset
def create_dataset(opt):
    """Build the REBlur test split from the HDF5 data at ``opt.data_path_test``.

    Delegates to ``concatenate_h5_datasets`` (presumably from ``utils``) with the
    ``REBlur`` dataset class and ``num_bin=6``. Six bins match
    ``event_channels=6`` used when ``RE_Net`` is constructed in ``prepare()``.

    NOTE(review): ``opt.num_bin`` is parsed but not used here; the bin count is
    fixed at 6.
    """
    return concatenate_h5_datasets(REBlur, opt.data_path_test, num_bin=6)
# create dataloader
def create_dataloader(test_dataset, opt):
    """Wrap ``test_dataset`` in a DataLoader that keeps the dataset order.

    Batch size comes from ``opt.test_batch_size``; shuffling is disabled so
    output indices follow the dataset order.
    """
    loader_kwargs = {'batch_size': opt.test_batch_size, 'shuffle': False}
    return DataLoader(test_dataset, **loader_kwargs)
# output metrics information
def log_metrics(metrics):
    """Return a one-line summary of the averaged MSE, PSNR and SSIM values.

    ``metrics`` must map the keys 'MSE', 'PSNR' and 'SSIM' to objects that
    expose an ``avg`` attribute. MSE is printed with 6 decimals, the others
    with 3.
    """
    template = 'MSE: {:.6f} PSNR: {:.3f} SSIM: {:.3f}'
    averages = [metrics[name].avg for name in ('MSE', 'PSNR', 'SSIM')]
    return template.format(*averages)
def prepare():
    """Set up the module-level ``device``, ``test_loader`` and ``unet`` for ``detect()``.

    Reads the module-level ``opt`` (assigned in the ``__main__`` block) for the
    test data path, batch size and checkpoint path, and loads the checkpoint's
    ``'state_dict'`` entry into the DataParallel-wrapped ``RE_Net``.

    NOTE(review): ``opt.load_unet`` is parsed but not consulted; the checkpoint
    is always loaded.
    """
    global test_loader, unet, device
    # basic settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # dataset and dataloader
    test_dataset = create_dataset(opt)
    test_loader = create_dataloader(test_dataset, opt)
    # model setting; the DataParallel wrapper presumably matches 'module.'-prefixed
    # keys in the saved state_dict — confirm against the training script.
    unet = nn.DataParallel(RE_Net(out_channels=6, event_channels=6)).to(device)
    # map_location lets a checkpoint saved from GPU load on the CPU fallback device.
    checkpoint = torch.load(opt.load_path, map_location=device)
    unet.load_state_dict(checkpoint['state_dict'])
def integral_cal_sharp(blur_image, res_pre):
    """Recover a reference frame ``L_f`` and per-channel frames ``L_t`` from residuals.

    Args:
        blur_image: blurry input, assumed ``(B, 1, H, W)`` — TODO confirm.
        res_pre: predicted residuals of shape ``(B, N, H, W)``. N is presumably 6,
            since ``RE_Net`` is built with ``out_channels=6``.

    Returns:
        ``(L_f, L_t)``, where:

        - ``L_f = blur_image - sum_n(res_pre[:, n]) / (N + 1)``, with shape
          ``(B, 1, H, W)``.
        - ``L_t = L_f + res_pre``, with ``L_f`` repeated over the N channels.

        Algebraically, ``blur_image`` then equals the mean of ``L_f`` and the N
        channels of ``L_t``. For N = 6 the divisor is 7, as in the original
        hard-coded form.
    """
    num_res = res_pre.shape[1]
    # keepdim=True keeps the sum at (B, 1, H, W) so it subtracts per sample.
    # A (B, H, W) sum would broadcast against the batch axis and mix samples
    # (and make the addition below fail) whenever B > 1.
    res_sum = torch.sum(res_pre, dim=1, keepdim=True)
    L_f = blur_image - res_sum / (num_res + 1)
    L_t = L_f.repeat(1, num_res, 1, 1) + res_pre
    return L_f, L_t
def read_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result.

    The file is read as UTF-8 and closed before returning, even if parsing
    raises.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
from measure import mse,psnr,ssim
def detect(epoch, loader):
    """Run the loaded network over ``loader`` and save blurry/deblurred/sharp images.

    Uses the module-level ``unet`` and ``device`` set up by ``prepare()``. For
    every batch it:

    - reads ``'blur_image'``, ``'voxel'`` and ``'sharp_image'`` from the item dict;
    - predicts residuals with ``unet(blur_image, voxel)``;
    - derives ``L_f`` via ``integral_cal_sharp``;
    - for each sample, saves channel 0 of the blurry input, of ``L_f`` and of
      the sharp target.

    Args:
        epoch: label used for the output sub-directory and file-name prefix.
        loader: iterable of batches; ``len(loader)`` sizes the progress bar.

    NOTE(review): the directory created is ``Result/{epoch}`` while
    ``save_image`` is given ``{epoch}/...``. Presumably ``save_image`` (from
    ``utils``) prefixes ``Result/`` — confirm. No metrics are computed here even
    though ``mse``/``psnr``/``ssim`` are imported.
    """
    # Created once up front rather than on every batch; exist_ok keeps reruns safe.
    os.makedirs(f'Result/{epoch}', exist_ok=True)
    with torch.no_grad():
        unet.eval()
        # The context manager closes the progress bar even if a batch raises.
        with tqdm(total=len(loader)) as pbar:
            for i, item in enumerate(loader):
                # load data
                blur_image = item['blur_image'].float().to(device)
                voxel = item['voxel'].float().to(device)
                sharp = item['sharp_image'].float().to(device)
                # calculation
                res_pre = unet(blur_image, voxel)
                L_f, L_t = integral_cal_sharp(blur_image, res_pre)
                # visualization results
                prefix = f'{epoch}/{str(i).zfill(4)}'
                for j in range(len(blur_image)):
                    save_image(np.array(blur_image[j, 0].detach().cpu()), f'{prefix}_blur_{j}')
                    save_image(np.array(L_f[j, 0].detach().cpu()), f'{prefix}_deblur_{j}')
                    save_image(np.array(sharp[j, 0].detach().cpu()), f'{prefix}_sharp_{j}')
                pbar.update(1)
def _as_bool(value):
    """Map the exact strings 'True'/'False' to booleans; return other values unchanged."""
    if value == 'True':
        return True
    if value == 'False':
        return False
    return value


def get_parser():
    """Build the CLI (defaults read from ./config.yaml) and return the parsed options.

    Integer options are converted with ``type=int`` so values given on the
    command line are ints, not strings. ``rgb`` and ``load_unet`` have the
    strings 'True'/'False' normalised to booleans.
    """
    dic = read_yaml('config.yaml')
    parser = argparse.ArgumentParser()
    # dataset path settings
    parser.add_argument("--data_path_test", default=dic['REBlur']['test'])
    # train & test settings
    parser.add_argument("--test_batch_size", type=int, default=dic['test_setting']['batch_size'])
    # model parameter settings
    parser.add_argument("--num_bin", type=int, default=dic['num_bin'])
    # model loading path
    parser.add_argument("--load_path", default=dic['unet']['load_path'])
    # load model
    parser.add_argument("--load_unet", default=dic['unet']['load'])
    parser.add_argument("--seed", type=int, default=dic['seed'])
    # opt.rgb is normalised below, so the option must exist; without it the
    # Namespace has no 'rgb' attribute and accessing it raises AttributeError.
    parser.add_argument("--rgb", default=dic.get('rgb', False))
    opt = parser.parse_args()
    # Values typed on the command line arrive as strings.
    opt.rgb = _as_bool(opt.rgb)
    opt.load_unet = _as_bool(opt.load_unet)
    return opt
if __name__ == "__main__":
    # Assignment at module level already binds the global ``opt`` that prepare()
    # reads, so a ``global`` statement here would be a no-op.
    opt = get_parser()
    prepare()
    # 'test_REBlur' becomes the output sub-directory / file prefix used by detect().
    detect('test_REBlur', test_loader)