Skip to content

Commit 7342874

Browse files
committed
Linting + formatting pass
1 parent f42be8d commit 7342874

File tree

2 files changed

+69
-119
lines changed

2 files changed

+69
-119
lines changed

__init__.py

Lines changed: 65 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1-
import torch, os
1+
import asyncio
2+
import os
3+
import sys
4+
5+
import comfy.utils
26
import numpy as np
3-
import torchvision.transforms as transforms
7+
import torch
48
from PIL import Image
5-
import comfy.utils
6-
import asyncio
9+
from torchvision import transforms
710

811
from .Pixelization.models import c2pGen
912
from .Pixelization.models.networks import define_G
13+
from .Pixelization.test_pro import MLP_code
1014

11-
import sys
1215

1316
def has_mps() -> bool:
1417
if sys.platform != "darwin":
1518
return False
16-
else:
17-
return torch.backends.mps.is_available()
18-
19+
return torch.backends.mps.is_available()
20+
21+
1922
def get_cuda_device_string():
2023
return "cuda"
2124

25+
2226
def get_optimal_device_name():
2327
if torch.cuda.is_available():
2428
return get_cuda_device_string()
@@ -28,84 +32,30 @@ def get_optimal_device_name():
2832

2933
return "cpu"
3034

35+
3136
def get_optimal_device():
3237
return torch.device(get_optimal_device_name())
3338

3439

3540
device = get_optimal_device()
3641

37-
# From https://github.com/AUTOMATIC1111/stable-diffusion-webui-pixelization/tree/master
38-
39-
pixelize_code = [
40-
233356.8125, -27387.5918, -32866.8008, 126575.0312, -181590.0156,
41-
-31543.1289, 50374.1289, 99631.4062, -188897.3750, 138322.7031,
42-
-107266.2266, 125778.5781, 42416.1836, 139710.8594, -39614.6250,
43-
-69972.6875, -21886.4141, 86938.4766, 31457.6270, -98892.2344,
44-
-1191.5887, -61662.1719, -180121.9062, -32931.0859, 43109.0391,
45-
21490.1328, -153485.3281, 94259.1797, 43103.1992, -231953.8125,
46-
52496.7422, 142697.4062, -34882.7852, -98740.0625, 34458.5078,
47-
-135436.3438, 11420.5488, -18895.8984, -71195.4141, 176947.2344,
48-
-52747.5742, 109054.6562, -28124.9473, -17736.6152, -41327.1562,
49-
69853.3906, 79046.2656, -3923.7344, -5644.5229, 96586.7578,
50-
-89315.2656, -146578.0156, -61862.1484, -83956.4375, 87574.5703,
51-
-75055.0469, 19571.8203, 79358.7891, -16501.5000, -147169.2188,
52-
-97861.6797, 60442.1797, 40156.9023, 223136.3906, -81118.0547,
53-
-221443.6406, 54911.6914, 54735.9258, -58805.7305, -168884.4844,
54-
40865.9609, -28627.9043, -18604.7227, 120274.6172, 49712.2383,
55-
164402.7031, -53165.0820, -60664.0469, -97956.1484, -121468.4062,
56-
-69926.1484, -4889.0151, 127367.7344, 200241.0781, -85817.7578,
57-
-143190.0625, -74049.5312, 137980.5781, -150788.7656, -115719.6719,
58-
-189250.1250, -153069.7344, -127429.7891, -187588.2500, 125264.7422,
59-
-79082.3438, -114144.5781, 36033.5039, -57502.2188, 80488.1562,
60-
36501.4570, -138817.5938, -22189.6523, -222146.9688, -73292.3984,
61-
127717.2422, -183836.3750, -105907.0859, 145422.8750, 66981.2031,
62-
-9596.6699, 78099.4922, 70226.3359, 35841.8789, -116117.6016,
63-
-150986.0156, 81622.4922, 113575.0625, 154419.4844, 53586.4141,
64-
118494.8750, 131625.4375, -19763.1094, 75581.1172, -42750.5039,
65-
97934.8281, 6706.7949, -101179.0078, 83519.6172, -83054.8359,
66-
-56749.2578, -30683.6992, 54615.9492, 84061.1406, -229136.7188,
67-
-60554.0000, 8120.2622, -106468.7891, -28316.3418, -166351.3125,
68-
47797.3984, 96013.4141, 71482.9453, -101429.9297, 209063.3594,
69-
-3033.6882, -38952.5352, -84920.6719, -5895.1543, -18641.8105,
70-
47884.3633, -14620.0273, -132898.6719, -40903.5859, 197217.3750,
71-
-128599.1328, -115397.8906, -22670.7676, -78569.9688, -54559.7070,
72-
-106855.2031, 40703.1484, 55568.3164, 60202.9844, -64757.9375,
73-
-32068.8652, 160663.3438, 72187.0703, -148519.5469, 162952.8906,
74-
-128048.2031, -136153.8906, -15270.3730, -52766.3281, -52517.4531,
75-
18652.1992, 195354.2188, -136657.3750, -8034.2622, -92699.6016,
76-
-129169.1406, 188479.9844, 46003.7500, -93383.0781, -67831.6484,
77-
-66710.5469, 104338.5234, 85878.8438, -73165.2031, 95857.3203,
78-
71213.1250, 94603.1094, -30359.8125, -107989.2578, 99822.1719,
79-
184626.3594, 79238.4531, -272978.9375, -137948.5781, -145245.8125,
80-
75359.2031, 26652.7930, 50421.4141, 60784.4102, -18286.3398,
81-
-182851.9531, -87178.7969, -13131.7539, 195674.8906, 59951.7852,
82-
124353.7422, -36709.1758, -54575.4766, 77822.6953, 43697.4102,
83-
-64394.3438, 113281.1797, -93987.0703, 221989.7188, 132902.5000,
84-
-9538.8574, -14594.1338, 65084.9453, -12501.7227, 130330.6875,
85-
-115123.4766, 20823.0898, 75512.4922, -75255.7422, -41936.7656,
86-
-186678.8281, -166799.9375, 138770.6250, -78969.9531, 124516.8047,
87-
-85558.5781, -69272.4375, -115539.1094, 228774.4844, -76529.3281,
88-
-107735.8906, -76798.8906, -194335.2812, 56530.5742, -9397.7529,
89-
132985.8281, 163929.8438, -188517.7969, -141155.6406, 45071.0391,
90-
207788.3125, -125826.1172, 8965.3320, -159584.8438, 95842.4609,
91-
-76929.4688
92-
]
9342

9443
basedir = os.path.dirname(os.path.realpath(__file__))
9544
path_checkpoints = os.path.join(basedir, "checkpoints")
9645
path_pixelart_vgg19 = os.path.join(path_checkpoints, "pixelart_vgg19.pth")
9746
path_160_net_G_A = os.path.join(path_checkpoints, "160_net_G_A.pth")
9847
path_alias_net = os.path.join(path_checkpoints, "alias_net.pth")
9948

49+
10050
class TorchHijackForC2pGen:
10151
def __getattr__(self, item):
102-
if item == 'load':
52+
if item == "load":
10353
return self.load
10454

10555
if hasattr(torch, item):
10656
return getattr(torch, item)
10757

108-
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
58+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
10959

11060
def load(self, filename, *args, **kwargs):
11161
if filename == "./pixelart_vgg19.pth":
@@ -116,31 +66,37 @@ def load(self, filename, *args, **kwargs):
11666

11767
c2pGen.torch = TorchHijackForC2pGen()
11868

69+
11970
class Model(torch.nn.Module):
12071
def __init__(self):
12172
super().__init__()
12273

123-
self.G_A_net = None
124-
self.alias_net = None
125-
126-
def load(self):
12774
os.makedirs(path_checkpoints, exist_ok=True)
12875

129-
missing = False
76+
models_missing = False
13077

13178
if not os.path.exists(path_pixelart_vgg19):
132-
print(f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM")
133-
missing = True
79+
print(
80+
f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM"
81+
)
82+
models_missing = True
13483

13584
if not os.path.exists(path_160_net_G_A):
136-
print(f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az")
137-
missing = True
85+
print(
86+
f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az"
87+
)
88+
models_missing = True
13889

13990
if not os.path.exists(path_alias_net):
140-
print(f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_")
141-
missing = True
91+
print(
92+
f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_"
93+
)
94+
models_missing = True
14295

143-
assert not missing, 'Missing checkpoints for pixelization - see console for download links.'
96+
if models_missing:
97+
error_message = "Missing checkpoints for pixelization - see console for download links."
98+
print(error_message)
99+
raise RuntimeError(error_message)
144100

145101
with torch.no_grad():
146102
self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
@@ -157,16 +113,21 @@ def load(self):
157113
self.alias_net.load_state_dict(alias_state)
158114

159115

160-
def process(img):
161-
ow, oh = img.size
116+
def rescale_image(img):
117+
"""
118+
Preprocess the image for pixelization.
162119
163-
nw = int(round(ow / 4) * 4)
164-
nh = int(round(oh / 4) * 4)
120+
Crops the image to a size that is divisible by 4.
121+
"""
122+
orig_width, orig_height = img.size
165123

166-
left = (ow - nw) // 2
167-
top = (oh - nh) // 2
168-
right = left + nw
169-
bottom = top + nh
124+
new_width = int(round(orig_width / 4) * 4)
125+
new_height = int(round(orig_height / 4) * 4)
126+
127+
left = (orig_width - new_width) // 2
128+
top = (orig_height - new_height) // 2
129+
right = left + new_width
130+
bottom = top + new_height
170131

171132
img = img.crop((left, top, right, bottom))
172133

@@ -180,14 +141,16 @@ def to_image(tensor, pixel_size, upscale_after):
180141
img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
181142
img = img.astype(np.uint8)
182143
img = Image.fromarray(img)
183-
img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST)
144+
img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST)
184145
if upscale_after:
185-
img = img.resize((img.size[0]*pixel_size, img.size[1]*pixel_size), resample=Image.Resampling.NEAREST)
146+
img = img.resize((img.size[0] * pixel_size, img.size[1] * pixel_size), resample=Image.Resampling.NEAREST)
186147

187148
return img
188149

150+
189151
def tensor2pil(image):
190-
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
152+
return Image.fromarray(np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
153+
191154

192155
def pil2tensor(image):
193156
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
@@ -211,15 +174,14 @@ async def run_async():
211174

212175
return res[0]
213176

214-
215-
class Pixelization:
216-
model = None
217177

178+
class Pixelization:
218179
def __init__(self):
219-
pass
180+
if not hasattr(self, "model"):
181+
self.model = Model()
220182

221183
@classmethod
222-
def INPUT_TYPES(s):
184+
def INPUT_TYPES(cls):
223185
return {
224186
"required": {
225187
"image": ("IMAGE",),
@@ -241,11 +203,12 @@ async def run_pixelatization(self, image, pixel_size, upscale_after):
241203
image = image.resize((image.width * 4 // pixel_size, image.height * 4 // pixel_size))
242204

243205
with torch.no_grad():
244-
in_t = process(image).to(device)
206+
in_t = rescale_image(image).to(device)
245207

246-
feature = self.model.G_A_net.module.RGBEnc(in_t)
247-
code = torch.asarray(pixelize_code, device=device).reshape((1, 256, 1, 1))
208+
code = torch.asarray(MLP_code, device=device).reshape((1, 256, 1, 1))
248209
adain_params = self.model.G_A_net.module.MLP(code)
210+
211+
feature = self.model.G_A_net.module.RGBEnc(in_t)
249212
images = self.model.G_A_net.module.RGBDec(feature, adain_params)
250213
out_t = self.model.alias_net(images)
251214

@@ -256,32 +219,19 @@ async def run_pixelatization(self, image, pixel_size, upscale_after):
256219
return image
257220

258221
def pixelize(self, image, pixel_size, upscale_after):
259-
if self.model is None:
260-
model = Model()
261-
model.load()
262-
263-
self.model = model
264-
265222
self.model.to(device)
266223

267-
tensor = image*255
224+
tensor = image * 255
268225
tensor = np.array(tensor, dtype=np.uint8)
269226

270-
pbar = comfy.utils.ProgressBar(tensor.shape[0])
227+
progressbar = comfy.utils.ProgressBar(tensor.shape[0])
271228
all_images = []
272229
for i in range(tensor.shape[0]):
273230
image = Image.fromarray(tensor[i])
274-
all_images.append((
275-
wait_for_async(lambda: self.run_pixelatization(image, pixel_size, upscale_after))
276-
))
277-
pbar.update(1)
231+
all_images.append(wait_for_async(lambda: self.run_pixelatization(image, pixel_size, upscale_after)))
232+
progressbar.update(1)
278233

279234
return (all_images,)
280235

281236

282-
283-
284-
NODE_CLASS_MAPPINGS = {
285-
"Pixelization": Pixelization
286-
}
287-
237+
NODE_CLASS_MAPPINGS = {"Pixelization": Pixelization}

install.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import contextlib
12
import os
23
import subprocess
34

5+
46
def git(*args):
5-
return subprocess.check_call(['git'] + list(args))
7+
return subprocess.check_call(["git", *list(args)])
68

79

810
path = os.path.dirname(os.path.realpath(__file__))
@@ -13,7 +15,5 @@ def git(*args):
1315
git("checkout", "b7142536da3a9348794bce260c10e465b8bebcb8")
1416

1517
# we remove __init__ because it breaks BLIP - takes over the directory named models which BLIP also uses.
16-
try:
18+
with contextlib.suppress(OSError):
1719
os.remove(os.path.join(path, "pixelization", "models", "__init__.py"))
18-
except OSError as e:
19-
pass

0 commit comments

Comments
 (0)