# image_generation.py — horse2zebra CycleGAN inference (pretrained ResNet generator)
import torch
import torch.nn as nn
from PIL import Image
from utils import Utils
from torchvision import transforms
class ImageGeneration:
    """Run a pretrained horse2zebra CycleGAN generator over a single PIL image.

    All methods are stateless helpers; the entry point is `start_pipeline`,
    whose return value is then passed to `normalize_img`.
    """

    @staticmethod
    def __get_resNetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9) -> "ResNetGenerator":
        """Build the generator network.

        The return annotation is a string on purpose: ResNetGenerator is
        defined later in this module, so a bare name here would raise
        NameError while this class body is being evaluated.
        """
        netG = ResNetGenerator(input_nc, output_nc, ngf, n_blocks)
        print(netG)  # log the architecture for inspection
        return netG

    @staticmethod
    def __get_transform_compose() -> transforms.Compose:
        """Preprocessing: resize shorter side to 256, convert to CHW float tensor in [0, 1]."""
        return transforms.Compose([transforms.Resize(256),
                                   transforms.ToTensor()])

    @staticmethod
    def start_pipeline(img: Image.Image):
        """Run the pretrained generator on *img*.

        Args:
            img: input PIL image (e.g. a horse photo).

        Returns:
            The raw generator output batch tensor (final layer is Tanh, so
            values lie in [-1, 1]), or None when *img* is missing.
        """
        if img is None:
            print("img is None, (invalid inputs)")
            return
        model_data = Utils.load_pretrained_weights("utils_files/horse2zebra_0.4.0.pth")
        netG = ImageGeneration.__get_resNetGenerator()
        Utils.compare_models(model_data, netG)
        netG.load_state_dict(model_data)
        netG.eval()  # inference mode for InstanceNorm/any dropout layers
        transform = ImageGeneration.__get_transform_compose()
        img_t = transform(img)
        batch_t = torch.unsqueeze(img_t, 0)  # add batch dimension
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            batch_out = netG(batch_t)
        print(batch_out)
        return batch_out

    @staticmethod
    def normalize_img(batch_out, mustShow: bool = True, mustSave: bool = True):
        """Map generator output from Tanh range [-1, 1] to [0, 1] and convert to PIL.

        Args:
            batch_out: generator output batch tensor from `start_pipeline`.
            mustShow: display the image in the default viewer.
            mustSave: write the image to imgs/horse2zebra_result.jpg.
        """
        out_t = (batch_out.data.squeeze() + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        out_img = transforms.ToPILImage()(out_t)
        print(out_img)
        if mustShow:
            out_img.show()
        if mustSave:
            out_img.save("imgs/horse2zebra_result.jpg")
class ResNetBlock(nn.Module):
    """Residual block: returns x + F(x).

    F is pad -> 3x3 conv -> instance norm -> ReLU -> pad -> 3x3 conv -> instance norm,
    with reflection padding so spatial size is preserved.
    """

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Assemble the residual branch for `dim` input/output channels."""
        # Layer order (and therefore Sequential indices / state_dict keys)
        # matches the pretrained checkpoints this project loads.
        return nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        """Skip connection: input plus the residual branch output."""
        residual = self.conv_block(x)
        return x + residual
class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator (CycleGAN architecture).

    Stem conv -> 2 downsampling convs -> n_blocks residual blocks ->
    2 upsampling transposed convs -> output conv -> Tanh.

    NOTE: the exact construction order below is significant — pretrained
    state_dicts are keyed by the Sequential indices built here, so do not
    reorder the layer appends.
    """

    def __init__(self, input_nc, output_nc, ngf, n_blocks):
        # input_nc / output_nc: channel counts of input and output images.
        # ngf: base number of generator filters; doubled at each downsampling.
        # n_blocks: number of residual blocks at the bottleneck.
        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # Stem: reflection-padded 7x7 conv, spatial size preserved.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]
        # Downsampling: two stride-2 convs, each doubling the channel count
        # (ngf -> 2*ngf -> 4*ngf) and halving spatial resolution.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]
        # Bottleneck: n_blocks residual blocks at 4*ngf channels.
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]
        # Upsampling: two stride-2 transposed convs mirroring the downsampling,
        # halving channels and doubling spatial resolution each time.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        # Output head: 7x7 conv back to output_nc channels, Tanh -> [-1, 1].
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Apply the full generator to a (N, input_nc, H, W) batch."""
        return self.model(input)