1- import torch , os
1+ import asyncio
2+ import os
3+ import sys
4+
5+ import comfy .utils
26import numpy as np
3- import torchvision . transforms as transforms
7+ import torch
48from PIL import Image
5- import comfy .utils
6- import asyncio
9+ from torchvision import transforms
710
811from .Pixelization .models import c2pGen
912from .Pixelization .models .networks import define_G
13+ from .Pixelization .test_pro import MLP_code
1014
11- import sys
1215
1316def has_mps () -> bool :
1417 if sys .platform != "darwin" :
1518 return False
16- else :
17- return torch . backends . mps . is_available ()
18-
19+ return torch . backends . mps . is_available ()
20+
21+
1922def get_cuda_device_string ():
2023 return "cuda"
2124
25+
2226def get_optimal_device_name ():
2327 if torch .cuda .is_available ():
2428 return get_cuda_device_string ()
@@ -28,84 +32,30 @@ def get_optimal_device_name():
2832
2933 return "cpu"
3034
35+
3136def get_optimal_device ():
3237 return torch .device (get_optimal_device_name ())
3338
3439
3540device = get_optimal_device ()
3641
37- # From https://github.com/AUTOMATIC1111/stable-diffusion-webui-pixelization/tree/master
38-
39- pixelize_code = [
40- 233356.8125 , - 27387.5918 , - 32866.8008 , 126575.0312 , - 181590.0156 ,
41- - 31543.1289 , 50374.1289 , 99631.4062 , - 188897.3750 , 138322.7031 ,
42- - 107266.2266 , 125778.5781 , 42416.1836 , 139710.8594 , - 39614.6250 ,
43- - 69972.6875 , - 21886.4141 , 86938.4766 , 31457.6270 , - 98892.2344 ,
44- - 1191.5887 , - 61662.1719 , - 180121.9062 , - 32931.0859 , 43109.0391 ,
45- 21490.1328 , - 153485.3281 , 94259.1797 , 43103.1992 , - 231953.8125 ,
46- 52496.7422 , 142697.4062 , - 34882.7852 , - 98740.0625 , 34458.5078 ,
47- - 135436.3438 , 11420.5488 , - 18895.8984 , - 71195.4141 , 176947.2344 ,
48- - 52747.5742 , 109054.6562 , - 28124.9473 , - 17736.6152 , - 41327.1562 ,
49- 69853.3906 , 79046.2656 , - 3923.7344 , - 5644.5229 , 96586.7578 ,
50- - 89315.2656 , - 146578.0156 , - 61862.1484 , - 83956.4375 , 87574.5703 ,
51- - 75055.0469 , 19571.8203 , 79358.7891 , - 16501.5000 , - 147169.2188 ,
52- - 97861.6797 , 60442.1797 , 40156.9023 , 223136.3906 , - 81118.0547 ,
53- - 221443.6406 , 54911.6914 , 54735.9258 , - 58805.7305 , - 168884.4844 ,
54- 40865.9609 , - 28627.9043 , - 18604.7227 , 120274.6172 , 49712.2383 ,
55- 164402.7031 , - 53165.0820 , - 60664.0469 , - 97956.1484 , - 121468.4062 ,
56- - 69926.1484 , - 4889.0151 , 127367.7344 , 200241.0781 , - 85817.7578 ,
57- - 143190.0625 , - 74049.5312 , 137980.5781 , - 150788.7656 , - 115719.6719 ,
58- - 189250.1250 , - 153069.7344 , - 127429.7891 , - 187588.2500 , 125264.7422 ,
59- - 79082.3438 , - 114144.5781 , 36033.5039 , - 57502.2188 , 80488.1562 ,
60- 36501.4570 , - 138817.5938 , - 22189.6523 , - 222146.9688 , - 73292.3984 ,
61- 127717.2422 , - 183836.3750 , - 105907.0859 , 145422.8750 , 66981.2031 ,
62- - 9596.6699 , 78099.4922 , 70226.3359 , 35841.8789 , - 116117.6016 ,
63- - 150986.0156 , 81622.4922 , 113575.0625 , 154419.4844 , 53586.4141 ,
64- 118494.8750 , 131625.4375 , - 19763.1094 , 75581.1172 , - 42750.5039 ,
65- 97934.8281 , 6706.7949 , - 101179.0078 , 83519.6172 , - 83054.8359 ,
66- - 56749.2578 , - 30683.6992 , 54615.9492 , 84061.1406 , - 229136.7188 ,
67- - 60554.0000 , 8120.2622 , - 106468.7891 , - 28316.3418 , - 166351.3125 ,
68- 47797.3984 , 96013.4141 , 71482.9453 , - 101429.9297 , 209063.3594 ,
69- - 3033.6882 , - 38952.5352 , - 84920.6719 , - 5895.1543 , - 18641.8105 ,
70- 47884.3633 , - 14620.0273 , - 132898.6719 , - 40903.5859 , 197217.3750 ,
71- - 128599.1328 , - 115397.8906 , - 22670.7676 , - 78569.9688 , - 54559.7070 ,
72- - 106855.2031 , 40703.1484 , 55568.3164 , 60202.9844 , - 64757.9375 ,
73- - 32068.8652 , 160663.3438 , 72187.0703 , - 148519.5469 , 162952.8906 ,
74- - 128048.2031 , - 136153.8906 , - 15270.3730 , - 52766.3281 , - 52517.4531 ,
75- 18652.1992 , 195354.2188 , - 136657.3750 , - 8034.2622 , - 92699.6016 ,
76- - 129169.1406 , 188479.9844 , 46003.7500 , - 93383.0781 , - 67831.6484 ,
77- - 66710.5469 , 104338.5234 , 85878.8438 , - 73165.2031 , 95857.3203 ,
78- 71213.1250 , 94603.1094 , - 30359.8125 , - 107989.2578 , 99822.1719 ,
79- 184626.3594 , 79238.4531 , - 272978.9375 , - 137948.5781 , - 145245.8125 ,
80- 75359.2031 , 26652.7930 , 50421.4141 , 60784.4102 , - 18286.3398 ,
81- - 182851.9531 , - 87178.7969 , - 13131.7539 , 195674.8906 , 59951.7852 ,
82- 124353.7422 , - 36709.1758 , - 54575.4766 , 77822.6953 , 43697.4102 ,
83- - 64394.3438 , 113281.1797 , - 93987.0703 , 221989.7188 , 132902.5000 ,
84- - 9538.8574 , - 14594.1338 , 65084.9453 , - 12501.7227 , 130330.6875 ,
85- - 115123.4766 , 20823.0898 , 75512.4922 , - 75255.7422 , - 41936.7656 ,
86- - 186678.8281 , - 166799.9375 , 138770.6250 , - 78969.9531 , 124516.8047 ,
87- - 85558.5781 , - 69272.4375 , - 115539.1094 , 228774.4844 , - 76529.3281 ,
88- - 107735.8906 , - 76798.8906 , - 194335.2812 , 56530.5742 , - 9397.7529 ,
89- 132985.8281 , 163929.8438 , - 188517.7969 , - 141155.6406 , 45071.0391 ,
90- 207788.3125 , - 125826.1172 , 8965.3320 , - 159584.8438 , 95842.4609 ,
91- - 76929.4688
92- ]
9342
9443basedir = os .path .dirname (os .path .realpath (__file__ ))
9544path_checkpoints = os .path .join (basedir , "checkpoints" )
9645path_pixelart_vgg19 = os .path .join (path_checkpoints , "pixelart_vgg19.pth" )
9746path_160_net_G_A = os .path .join (path_checkpoints , "160_net_G_A.pth" )
9847path_alias_net = os .path .join (path_checkpoints , "alias_net.pth" )
9948
49+
10050class TorchHijackForC2pGen :
10151 def __getattr__ (self , item ):
102- if item == ' load' :
52+ if item == " load" :
10353 return self .load
10454
10555 if hasattr (torch , item ):
10656 return getattr (torch , item )
10757
108- raise AttributeError ("'{}' object has no attribute '{}'" . format ( type ( self ). __name__ , item ) )
58+ raise AttributeError (f "'{ type ( self ). __name__ } ' object has no attribute '{ item } '" )
10959
11060 def load (self , filename , * args , ** kwargs ):
11161 if filename == "./pixelart_vgg19.pth" :
@@ -116,31 +66,37 @@ def load(self, filename, *args, **kwargs):
11666
11767c2pGen .torch = TorchHijackForC2pGen ()
11868
69+
11970class Model (torch .nn .Module ):
12071 def __init__ (self ):
12172 super ().__init__ ()
12273
123- self .G_A_net = None
124- self .alias_net = None
125-
126- def load (self ):
12774 os .makedirs (path_checkpoints , exist_ok = True )
12875
129- missing = False
76+ models_missing = False
13077
13178 if not os .path .exists (path_pixelart_vgg19 ):
132- print (f"Missing { path_pixelart_vgg19 } - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM" )
133- missing = True
79+ print (
80+ f"Missing { path_pixelart_vgg19 } - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM"
81+ )
82+ models_missing = True
13483
13584 if not os .path .exists (path_160_net_G_A ):
136- print (f"Missing { path_160_net_G_A } - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az" )
137- missing = True
85+ print (
86+ f"Missing { path_160_net_G_A } - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az"
87+ )
88+ models_missing = True
13889
13990 if not os .path .exists (path_alias_net ):
140- print (f"Missing { path_alias_net } - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_" )
141- missing = True
91+ print (
92+ f"Missing { path_alias_net } - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_"
93+ )
94+ models_missing = True
14295
143- assert not missing , 'Missing checkpoints for pixelization - see console for download links.'
96+ if models_missing :
97+ error_message = "Missing checkpoints for pixelization - see console for download links."
98+ print (error_message )
99+ raise RuntimeError (error_message )
144100
145101 with torch .no_grad ():
146102 self .G_A_net = define_G (3 , 3 , 64 , "c2pGen" , "instance" , False , "normal" , 0.02 , [0 ])
@@ -157,16 +113,21 @@ def load(self):
157113 self .alias_net .load_state_dict (alias_state )
158114
159115
160- def process (img ):
161- ow , oh = img .size
116+ def rescale_image (img ):
117+ """
118+ Preprocess the image for pixelization.
162119
163- nw = int (round (ow / 4 ) * 4 )
164- nh = int (round (oh / 4 ) * 4 )
120+ Crops the image to a size that is divisible by 4.
121+ """
122+ orig_width , orig_height = img .size
165123
166- left = (ow - nw ) // 2
167- top = (oh - nh ) // 2
168- right = left + nw
169- bottom = top + nh
124+ new_width = int (round (orig_width / 4 ) * 4 )
125+ new_height = int (round (orig_height / 4 ) * 4 )
126+
127+ left = (orig_width - new_width ) // 2
128+ top = (orig_height - new_height ) // 2
129+ right = left + new_width
130+ bottom = top + new_height
170131
171132 img = img .crop ((left , top , right , bottom ))
172133
@@ -180,14 +141,16 @@ def to_image(tensor, pixel_size, upscale_after):
180141 img = (np .transpose (img , (1 , 2 , 0 )) + 1 ) / 2.0 * 255.0
181142 img = img .astype (np .uint8 )
182143 img = Image .fromarray (img )
183- img = img .resize ((img .size [0 ]// 4 , img .size [1 ]// 4 ), resample = Image .Resampling .NEAREST )
144+ img = img .resize ((img .size [0 ] // 4 , img .size [1 ] // 4 ), resample = Image .Resampling .NEAREST )
184145 if upscale_after :
185- img = img .resize ((img .size [0 ]* pixel_size , img .size [1 ]* pixel_size ), resample = Image .Resampling .NEAREST )
146+ img = img .resize ((img .size [0 ] * pixel_size , img .size [1 ] * pixel_size ), resample = Image .Resampling .NEAREST )
186147
187148 return img
188149
150+
189151def tensor2pil (image ):
190- return Image .fromarray (np .clip (255. * image .cpu ().numpy ().squeeze (), 0 , 255 ).astype (np .uint8 ))
152+ return Image .fromarray (np .clip (255.0 * image .cpu ().numpy ().squeeze (), 0 , 255 ).astype (np .uint8 ))
153+
191154
192155def pil2tensor (image ):
193156 return torch .from_numpy (np .array (image ).astype (np .float32 ) / 255.0 ).unsqueeze (0 )
@@ -211,15 +174,14 @@ async def run_async():
211174
212175 return res [0 ]
213176
214-
215- class Pixelization :
216- model = None
217177
178+ class Pixelization :
218179 def __init__ (self ):
219- pass
180+ if not hasattr (self , "model" ):
181+ self .model = Model ()
220182
221183 @classmethod
222- def INPUT_TYPES (s ):
184+ def INPUT_TYPES (cls ):
223185 return {
224186 "required" : {
225187 "image" : ("IMAGE" ,),
@@ -241,11 +203,12 @@ async def run_pixelatization(self, image, pixel_size, upscale_after):
241203 image = image .resize ((image .width * 4 // pixel_size , image .height * 4 // pixel_size ))
242204
243205 with torch .no_grad ():
244- in_t = process (image ).to (device )
206+ in_t = rescale_image (image ).to (device )
245207
246- feature = self .model .G_A_net .module .RGBEnc (in_t )
247- code = torch .asarray (pixelize_code , device = device ).reshape ((1 , 256 , 1 , 1 ))
208+ code = torch .asarray (MLP_code , device = device ).reshape ((1 , 256 , 1 , 1 ))
248209 adain_params = self .model .G_A_net .module .MLP (code )
210+
211+ feature = self .model .G_A_net .module .RGBEnc (in_t )
249212 images = self .model .G_A_net .module .RGBDec (feature , adain_params )
250213 out_t = self .model .alias_net (images )
251214
@@ -256,32 +219,19 @@ async def run_pixelatization(self, image, pixel_size, upscale_after):
256219 return image
257220
258221 def pixelize (self , image , pixel_size , upscale_after ):
259- if self .model is None :
260- model = Model ()
261- model .load ()
262-
263- self .model = model
264-
265222 self .model .to (device )
266223
267- tensor = image * 255
224+ tensor = image * 255
268225 tensor = np .array (tensor , dtype = np .uint8 )
269226
270- pbar = comfy .utils .ProgressBar (tensor .shape [0 ])
227+ progressbar = comfy .utils .ProgressBar (tensor .shape [0 ])
271228 all_images = []
272229 for i in range (tensor .shape [0 ]):
273230 image = Image .fromarray (tensor [i ])
274- all_images .append ((
275- wait_for_async (lambda : self .run_pixelatization (image , pixel_size , upscale_after ))
276- ))
277- pbar .update (1 )
231+ all_images .append (wait_for_async (lambda : self .run_pixelatization (image , pixel_size , upscale_after )))
232+ progressbar .update (1 )
278233
279234 return (all_images ,)
280235
281236
282-
283-
284- NODE_CLASS_MAPPINGS = {
285- "Pixelization" : Pixelization
286- }
287-
237+ NODE_CLASS_MAPPINGS = {"Pixelization" : Pixelization }
0 commit comments