-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathStyleModel.py
More file actions
executable file
·59 lines (44 loc) · 1.86 KB
/
StyleModel.py
File metadata and controls
executable file
·59 lines (44 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
"""voidhacktrained.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1rWEyc6vQelTf8Gig-c2PpqEcvmIxVeq4
"""
import torch
from PIL import Image
from torchvision import transforms
import re
from StyleNetArchitecture import TransformerNet
class StyleImage():
    """Apply a trained style-transfer network to an image and save the result.

    The class keeps no instance state; its methods are grouped here for
    convenience only.
    """

    # Matches the running_mean / running_var entries that older checkpoints
    # stored for InstanceNorm layers; they must be dropped before loading.
    _DEPRECATED_STAT_KEY = re.compile(r'in\d+\.running_(mean|var)$')

    def stylize(self, model, content_image, content_scale=None,
                output_path="static/images/styled.jpg"):
        """Stylize ``content_image`` with the checkpoint ``model`` and save it.

        Args:
            model: Path (or file-like object) of the saved ``TransformerNet``
                state dict, passed to ``torch.load``.
            content_image: Path (or file-like object) of the image to stylize.
            content_scale: Optional factor by which the image dimensions are
                divided before stylizing. ``None`` keeps the original size.
            output_path: Destination of the stylized image. Defaults to the
                location the method has always written to.
        """
        device = torch.device("cpu")

        content_image = self.load_image(content_image, scale=content_scale)
        content_transform = transforms.Compose([
            transforms.ToTensor(),
            # The tensor is multiplied by 255 before being fed to the network.
            transforms.Lambda(lambda x: x.mul(255)),
        ])
        content_image = content_transform(content_image)
        # Add the batch dimension the network is called with.
        content_image = content_image.unsqueeze(0).to(device)

        with torch.no_grad():
            style_model = TransformerNet()
            # map_location lets checkpoints saved from a GPU load on a
            # CPU-only host.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            state_dict = torch.load(model, map_location=device)
            self._drop_deprecated_norm_stats(state_dict)
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            style_model.eval()  # inference mode for any train/eval-aware layers
            output = style_model(content_image).cpu()

        self.save_image(output_path, output[0])

    @staticmethod
    def _drop_deprecated_norm_stats(state_dict):
        """Delete deprecated InstanceNorm ``running_*`` keys in place.

        Returns the same mapping for convenience.
        """
        pattern = StyleImage._DEPRECATED_STAT_KEY
        # Iterate over a snapshot of the keys because entries are deleted.
        for key in list(state_dict):
            if pattern.search(key):
                del state_dict[key]
        return state_dict

    def load_image(self, filename, size=None, scale=None):
        """Open an image as RGB, optionally resizing it.

        Args:
            filename: Path (or file-like object) accepted by ``Image.open``.
            size: If given, the image is resized to ``size`` x ``size``.
            scale: Used only when ``size`` is ``None``; both dimensions are
                divided by this factor.

        Returns:
            The loaded ``PIL.Image.Image`` in RGB mode.
        """
        # convert() returns an independent, fully loaded image, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(filename) as source:
            img = source.convert("RGB")

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resample = Image.LANCZOS
        if size is not None:
            img = img.resize((size, size), resample)
        elif scale is not None:
            width, height = img.size
            # Never let a large scale collapse a dimension to zero pixels.
            new_size = (max(1, int(width / scale)), max(1, int(height / scale)))
            img = img.resize(new_size, resample)
        return img

    def save_image(self, filename, data):
        """Write an image tensor to ``filename``.

        Args:
            filename: Destination path; the format is chosen by Pillow from
                the file extension.
            data: 3-D tensor laid out channels-first; values are clamped to
                [0, 255] and truncated to ``uint8`` before saving.
        """
        pixels = data.detach().cpu().clone().clamp(0, 255).numpy()
        # Channels-first -> channels-last, the layout Image.fromarray expects.
        pixels = pixels.transpose(1, 2, 0).astype("uint8")
        Image.fromarray(pixels).save(filename)