-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpredictor.py
More file actions
119 lines (95 loc) · 3.97 KB
/
predictor.py
File metadata and controls
119 lines (95 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
import os
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from models.enhanced_resnet import EnhancedResnet
#import matplotlib
#%matplotlib inline
import matplotlib.pyplot as plt
#from data.data_utils import get_data
def plot_figs(img1,img2,img3=None, tot=2):
    """Show two (optionally three) images side by side in a single row.

    Args:
        img1: Image drawn in the first column; passed to ``plt.imshow``.
        img2: Image drawn in the second column.
        img3: Optional image drawn in the third column. It is only drawn
            when ``tot`` is at least 3 and ``img3`` is not ``None``.
        tot: Number of columns in the one-row subplot grid.
    """
    f = plt.figure(figsize=(6, 3))
    # NOTE(review): this runs before any subplot exists, so it presumably
    # affects an implicit figure-wide axes rather than the image subplots.
    # Kept as in the original; confirm the intent.
    plt.axis('off')
    for column, image in enumerate((img1, img2), start=1):
        f.add_subplot(1, tot, column)
        plt.imshow(image)
    # The original tested ``tot > 3``, which skipped the third image in a
    # three-column layout. Also skip when no third image was supplied.
    if tot >= 3 and img3 is not None:
        f.add_subplot(1, tot, 3)
        plt.imshow(img3)
    plt.show()
def imgdenoise(dt,idx,pixel=None):
    """Run one CIFAR10 image through the denoiser and plot input vs. output.

    Args:
        dt: ``"train"`` selects the CIFAR10 training split; any other value
            selects the test split.
        idx: Index of the image within the chosen split.
        pixel: Optional ``[x, y, r, g, b]`` sequence. When given, the tensor
            entry at ``[channel, x, y]`` is overwritten for the three
            channels with ``r, g, b`` divided by 255 before denoising.

    The function returns nothing; it prints a caption and opens a plot.
    """
    model = EnhancedResnet()
    # Both splits are loaded the same way; only the ``train`` flag differs.
    dataset = datasets.CIFAR10(
        './data',
        train=(dt == "train"),
        download=True,
        transform=transforms.ToTensor(),
    )
    tnsr, _ = dataset[idx]  # the label is not needed for denoising

    if pixel is not None:
        x, y = pixel[0], pixel[1]
        # pixel[2:5] are 0-255 colour values; scale them like the tensor data
        # (presumably 0-1 after ToTensor).
        for channel in range(3):
            tnsr[channel][x][y] = pixel[2 + channel] / 255

    # The denoiser is fed a batch of one image.
    tt = torch.reshape(tnsr, (1, 3, 32, 32))

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    dnl = torch.load('./utils/logs/denoiser.pth', map_location='cpu')
    model.denoised_layer.load_state_dict(dnl['model'])
    model.denoised_layer.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        wo = model.denoised_layer(tt)
    wo = torch.reshape(wo, (3, 32, 32))

    pilTrans = transforms.ToPILImage()
    # The first panel is the tensor that was fed to the denoiser, i.e. it
    # includes the pixel edit when ``pixel`` was given.
    pilImg1 = pilTrans(tnsr)
    pilImg2 = pilTrans(wo)
    print("1. Original Image 2. Denoise Image")
    plot_figs(pilImg1, pilImg2)
def predict(dt,idx,pixel=None):
    """Classify one CIFAR10 image with and without the denoising layer.

    Prints one result line for ``model.residualnet`` alone and one for the
    full ``EnhancedResnet`` model, then plots the untouched image next to
    the (optionally perturbed) image that was classified.

    Args:
        dt: ``"train"`` selects the CIFAR10 training split; any other value
            selects the test split.
        idx: Index of the image within the chosen split.
        pixel: Optional ``[x, y, r, g, b]`` sequence. When given, the tensor
            entry at ``[channel, x, y]`` is overwritten for the three
            channels with ``r, g, b`` divided by 255 before classification.
    """
    classes = {0:"Airplane",1:"Automobile",2:"Bird",3:"Cat",4:"Deer",5:"Dog",6:"Frog",7:"Horse",8:"Ship",9:"Truck"}
    model = EnhancedResnet()

    # Load the split once and derive the tensor from the PIL image, instead
    # of building the dataset a second time with a ToTensor transform.
    imgs = datasets.CIFAR10('./data', train=(dt == "train"), download=True)
    img, lb = imgs[idx]
    tnsr = transforms.ToTensor()(img)

    if pixel is not None:
        x, y = pixel[0], pixel[1]
        # pixel[2:5] are 0-255 colour values; scale them like the tensor data
        # (presumably 0-1 after ToTensor).
        for channel in range(3):
            tnsr[channel][x][y] = pixel[2 + channel] / 255

    pilImg = transforms.ToPILImage()(tnsr)

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    rsl = torch.load('./utils/logs/resnet.pth', map_location='cpu')
    dnl = torch.load('./utils/logs/denoiser.pth', map_location='cpu')
    # Each checkpoint goes into its own sub-module. The original loaded
    # them swapped (resnet weights into the denoiser and vice versa).
    model.denoised_layer.load_state_dict(dnl['model'])
    model.residualnet.load_state_dict(rsl['model'])
    model.eval()

    # Feed a batch of one image, matching how imgdenoise calls the denoiser.
    batch = tnsr.unsqueeze(0)
    with torch.no_grad():  # inference only
        wo = model.residualnet(batch)
        prdct = model(batch)

    # NOTE(review): "confidence" is the largest raw model output times 100;
    # whether that is a probability depends on EnhancedResnet - confirm.
    wovals = wo.detach().numpy()
    wolabel = np.argmax(wovals)
    wolevel = np.amax(wovals) * 100

    npvals = prdct.detach().numpy()
    label = np.argmax(npvals)
    level = np.amax(npvals) * 100

    report = "\n{} RESULT - Actual label is {} and model predicted it as {} with confidence {: .2f}"
    print(report.format("Resnet", classes[lb], classes[wolabel], wolevel))
    print(report.format("EnhancedResnet", classes[lb], classes[label], level))
    print("1. Original Image 2. Perturbed Image")
    plot_figs(img, pilImg)
if __name__ == '__main__':
    # Example: classify the image at index 99 of the CIFAR10 test set.
    # When pixel = [x, y, r, g, b] is given, that pixel is overwritten first
    # and the perturbed image is classified instead.
    # predict("test", idx=99, pixel=[16, 16, 255, 255, 0])

    # Denoise test image 58 after overwriting pixel (16, 16) with (255, 255, 0).
    imgdenoise("test", idx=58, pixel=[16, 16, 255, 255, 0])