@@ -17,227 +17,10 @@ def __getitem__(self, key):
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-
-def cubic_kernel(x, a: float = -0.75):
-    absx = x.abs()
-    absx2 = absx ** 2
-    absx3 = absx ** 3
-
-    w = (a + 2) * absx3 - (a + 3) * absx2 + 1
-    w2 = a * absx3 - 5 * a * absx2 + 8 * a * absx - 4 * a
-
-    return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x)))
-
-def get_indices_weights(in_size, out_size, scale):
-    # OpenCV-style half-pixel mapping
-    x = torch.arange(out_size, dtype=torch.float32)
-    x = (x + 0.5) / scale - 0.5
-
-    x0 = x.floor().long()
-    dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3))
-
-    weights = cubic_kernel(dx)
-    weights = weights / weights.sum(dim=1, keepdim=True)
-
-    indices = x0.unsqueeze(1) + torch.arange(-1, 3)
-    indices = indices.clamp(0, in_size - 1)
-
-    return indices, weights
-
-def resize_cubic_1d(x, out_size, dim):
-    b, c, h, w = x.shape
-    in_size = h if dim == 2 else w
-    scale = out_size / in_size
-
-    indices, weights = get_indices_weights(in_size, out_size, scale)
-
-    if dim == 2:
-        x = x.permute(0, 1, 3, 2)
-        x = x.reshape(-1, h)
-    else:
-        x = x.reshape(-1, w)
-
-    gathered = x[:, indices]
-    out = (gathered * weights.unsqueeze(0)).sum(dim=2)
-
-    if dim == 2:
-        out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
-    else:
-        out = out.reshape(b, c, h, out_size)
-
-    return out
-
-def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
-    """
-    Resize image using OpenCV-equivalent INTER_CUBIC interpolation.
-    Implemented in pure PyTorch
-    """
-
-    if img.ndim == 3:
-        img = img.unsqueeze(0)
-
-    img = img.permute(0, 3, 1, 2)
-
-    out_h, out_w = size
-    img = resize_cubic_1d(img, out_h, dim=2)
-    img = resize_cubic_1d(img, out_w, dim=3)
-    return img
-
-def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
-    # vectorized implementation for OpenCV's INTER_AREA using pure PyTorch
-    original_shape = img.shape
-    is_hwc = False
-
-    if img.ndim == 3:
-        if img.shape[0] <= 4:
-            img = img.unsqueeze(0)
-        else:
-            is_hwc = True
-            img = img.permute(2, 0, 1).unsqueeze(0)
-    elif img.ndim == 4:
-        pass
-    else:
-        raise ValueError("Expected image with 3 or 4 dims.")
-
-    B, C, H, W = img.shape
-    out_h, out_w = size
-    scale_y = H / out_h
-    scale_x = W / out_w
-
-    device = img.device
-
-    # compute the grid boundaries
-    y_start = torch.arange(out_h, device=device).float() * scale_y
-    y_end = y_start + scale_y
-    x_start = torch.arange(out_w, device=device).float() * scale_x
-    x_end = x_start + scale_x
-
-    # for each output pixel, we will compute the range for it
-    y_start_int = torch.floor(y_start).long()
-    y_end_int = torch.ceil(y_end).long()
-    x_start_int = torch.floor(x_start).long()
-    x_end_int = torch.ceil(x_end).long()
-
-    # We will build the weighted sums by iterating over contributing input pixels once
-    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
-    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)
-
-    max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
-    max_kernel_w = int(torch.max(x_end_int - x_start_int).item())
-
-    for dy in range(max_kernel_h):
-        for dx in range(max_kernel_w):
-            # compute the weights for this offset for all output pixels
-
-            y_idx = y_start_int.unsqueeze(1) + dy
-            x_idx = x_start_int.unsqueeze(0) + dx
-
-            # clamp indices to image boundaries
-            y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
-            x_idx_clamped = torch.clamp(x_idx, 0, W - 1)
-
-            # compute weights by broadcasting
-            y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
-            x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)
-
-            weight = (y_weight * x_weight)
-
-            y_expand = y_idx_clamped.expand(out_h, out_w)
-            x_expand = x_idx_clamped.expand(out_h, out_w)
-
-
-            pixels = img[:, :, y_expand, x_expand]
-
-            # unsqueeze to broadcast
-            w = weight.unsqueeze(0).unsqueeze(0)
-
-            output += pixels * w
-            area += weight
-
-    # Normalize by area
-    output /= area.unsqueeze(0).unsqueeze(0)
-
-    if is_hwc:
-        return output[0].permute(1, 2, 0)
-    elif img.shape[0] == 1 and original_shape[0] <= 4:
-        return output[0]
-    else:
-        return output
-
-def recenter(image, border_ratio: float = 0.2):
-
-    if image.shape[-1] == 4:
-        mask = image[..., 3]
-    else:
-        mask = torch.ones_like(image[..., 0:1]) * 255
-        image = torch.concatenate([image, mask], axis=-1)
-        mask = mask[..., 0]
-
-    H, W, C = image.shape
-
-    size = max(H, W)
-    result = torch.zeros((size, size, C), dtype=torch.uint8)
-
-    # as_tuple to match numpy behaviour
-    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)
-
-    y_min, y_max = y_coords.min(), y_coords.max()
-    x_min, x_max = x_coords.min(), x_coords.max()
-
-    h = x_max - x_min
-    w = y_max - y_min
-
-    if h == 0 or w == 0:
-        raise ValueError('input image is empty')
-
-    desired_size = int(size * (1 - border_ratio))
-    scale = desired_size / max(h, w)
-
-    h2 = int(h * scale)
-    w2 = int(w * scale)
-
-    x2_min = (size - h2) // 2
-    x2_max = x2_min + h2
-
-    y2_min = (size - w2) // 2
-    y2_max = y2_min + w2
-
-    # note: opencv takes columns first (opposite to pytorch and numpy that take the row first)
-    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))
-
-    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype=torch.uint8) * 255
-
-    mask = result[..., 3:].to(torch.float32) / 255
-    result = result[..., :3] * mask + bg * (1 - mask)
-
-    mask = mask * 255
-    result = result.clip(0, 255).to(torch.uint8)
-    mask = mask.clip(0, 255).to(torch.uint8)
-
-    return result
-
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
-                    crop=True, value_range=(-1, 1), border_ratio: float = None, recenter_size: int = 512):
-
-    if border_ratio is not None:
-
-        image = (image * 255).clamp(0, 255).to(torch.uint8)
-        image = [recenter(img, border_ratio=border_ratio) for img in image]
-
-        image = torch.stack(image, dim=0)
-        image = resize_cubic(image, size=(recenter_size, recenter_size))
-
-        image = image / 255 * 2 - 1
-        low, high = value_range
-
-        image = (image - low) / (high - low)
-        image = image.permute(0, 2, 3, 1)
-
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
-
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
-
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
         if crop:
@@ -246,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else " bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
@@ -288,9 +71,9 @@ def load_sd(self, sd):
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image, crop=True, border_ratio: float = None):
+    def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()