@@ -17,10 +17,227 @@ def __getitem__(self, key):
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+
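+# Keys cubic convolution kernel with a = -0.75, the coefficient OpenCV uses for INTER_CUBIC.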
+def cubic_kernel(x, a: float = -0.75):
+    absx = x.abs()
+    absx2 = absx ** 2
+    absx3 = absx ** 3
+
+    w = (a + 2) * absx3 - (a + 3) * absx2 + 1
+    w2 = a * absx3 - 5 * a * absx2 + 8 * a * absx - 4 * a
+
+    return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x)))
+
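+# Map each output coordinate to its four source taps and their normalized cubic weights.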
+def get_indices_weights(in_size, out_size, scale):
+    # OpenCV-style half-pixel mapping
+    x = torch.arange(out_size, dtype=torch.float32)
+    x = (x + 0.5) / scale - 0.5
+
+    x0 = x.floor().long()
+    dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3))
+
+    weights = cubic_kernel(dx)
+    weights = weights / weights.sum(dim=1, keepdim=True)
+
+    indices = x0.unsqueeze(1) + torch.arange(-1, 3)
+    indices = indices.clamp(0, in_size - 1)
+
+    return indices, weights
+
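+# Separable resize: one cubic filtering pass along a single axis (dim=2 for height, dim=3 for width).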
+def resize_cubic_1d(x, out_size, dim):
+    b, c, h, w = x.shape
+    in_size = h if dim == 2 else w
+    scale = out_size / in_size
+
+    indices, weights = get_indices_weights(in_size, out_size, scale)
+    # indices/weights are built on the CPU; keep them on the input's device
+    indices, weights = indices.to(x.device), weights.to(x.device)
+
+    if dim == 2:
+        x = x.permute(0, 1, 3, 2)
+        x = x.reshape(-1, h)
+    else:
+        x = x.reshape(-1, w)
+
+    gathered = x[:, indices]
+    out = (gathered * weights.unsqueeze(0)).sum(dim=2)
+
+    if dim == 2:
+        out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
+    else:
+        out = out.reshape(b, c, h, out_size)
+
+    return out
+
+def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
+    """
+    Resize an image using OpenCV-equivalent INTER_CUBIC interpolation,
+    implemented in pure PyTorch. Expects (N)HWC input, returns NCHW.
+    """
+    if img.ndim == 3:
+        img = img.unsqueeze(0)
+
+    img = img.permute(0, 3, 1, 2)
+
+    out_h, out_w = size
+    img = resize_cubic_1d(img, out_h, dim=2)
+    img = resize_cubic_1d(img, out_w, dim=3)
+    return img
+
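+# OpenCV INTER_AREA equivalent: each output pixel averages the input pixels its
+# footprint covers, with fractional weights at the footprint edges.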
+def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
+    # vectorized implementation of OpenCV's INTER_AREA in pure PyTorch
+    original_shape = img.shape
+    is_hwc = False
+
+    if img.ndim == 3:
+        if img.shape[0] <= 4:
+            img = img.unsqueeze(0)
+        else:
+            is_hwc = True
+            img = img.permute(2, 0, 1).unsqueeze(0)
+    elif img.ndim == 4:
+        pass
+    else:
+        raise ValueError("Expected image with 3 or 4 dims.")
+
+    B, C, H, W = img.shape
+    out_h, out_w = size
+    scale_y = H / out_h
+    scale_x = W / out_w
+
+    device = img.device
+
+    # compute the grid boundaries
+    y_start = torch.arange(out_h, device=device).float() * scale_y
+    y_end = y_start + scale_y
+    x_start = torch.arange(out_w, device=device).float() * scale_x
+    x_end = x_start + scale_x
+
+    # integer pixel range covered by each output pixel
+    y_start_int = torch.floor(y_start).long()
+    y_end_int = torch.ceil(y_end).long()
+    x_start_int = torch.floor(x_start).long()
+    x_end_int = torch.ceil(x_end).long()
+
+    # build the weighted sums by iterating over contributing input pixels once
+    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
+    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)
+
+    max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
+    max_kernel_w = int(torch.max(x_end_int - x_start_int).item())
+
+    for dy in range(max_kernel_h):
+        for dx in range(max_kernel_w):
+            # compute the weights for this offset for all output pixels
+            y_idx = y_start_int.unsqueeze(1) + dy
+            x_idx = x_start_int.unsqueeze(0) + dx
+
+            # clamp indices to image boundaries
+            y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
+            x_idx_clamped = torch.clamp(x_idx, 0, W - 1)
+
+            # compute weights by broadcasting
+            y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
+            x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)
+
+            weight = y_weight * x_weight
+
+            y_expand = y_idx_clamped.expand(out_h, out_w)
+            x_expand = x_idx_clamped.expand(out_h, out_w)
+
+            pixels = img[:, :, y_expand, x_expand]
+
+            # unsqueeze to broadcast over batch and channel dims
+            w = weight.unsqueeze(0).unsqueeze(0)
+
+            output += pixels * w
+            area += weight
+
+    # normalize by the accumulated area
+    output /= area.unsqueeze(0).unsqueeze(0)
+
+    if is_hwc:
+        return output[0].permute(1, 2, 0)
+    elif img.shape[0] == 1 and original_shape[0] <= 4:
+        return output[0]
+    else:
+        return output
+
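+# Crop the image to its foreground (alpha) bounding box and paste it, rescaled, into
+# the center of a square canvas, keeping border_ratio of the canvas as margin.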
+def recenter(image, border_ratio: float = 0.2):
+    if image.shape[-1] == 4:
+        mask = image[..., 3]
+    else:
+        mask = torch.ones_like(image[..., 0:1]) * 255
+        image = torch.cat([image, mask], dim=-1)
+        mask = mask[..., 0]
+
+    H, W, C = image.shape
+
+    size = max(H, W)
+    result = torch.zeros((size, size, C), dtype=torch.uint8, device=image.device)
+
+    # as_tuple to match numpy behaviour: rows first, then columns
+    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)
+
+    y_min, y_max = y_coords.min(), y_coords.max()
+    x_min, x_max = x_coords.min(), x_coords.max()
+
+    h = x_max - x_min
+    w = y_max - y_min
+
+    if h == 0 or w == 0:
+        raise ValueError('input image is empty')
+
+    desired_size = int(size * (1 - border_ratio))
+    scale = desired_size / max(h, w)
+
+    h2 = int(h * scale)
+    w2 = int(w * scale)
+
+    x2_min = (size - h2) // 2
+    x2_max = x2_min + h2
+
+    y2_min = (size - w2) // 2
+    y2_max = y2_min + w2
+
+    # note: OpenCV takes columns first (opposite to PyTorch and numpy, which take the row first)
+    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))
+
+    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype=torch.uint8, device=image.device) * 255
+
+    # composite over a white background using the alpha channel
+    mask = result[..., 3:].to(torch.float32) / 255
+    result = result[..., :3] * mask + bg * (1 - mask)
+
+    mask = mask * 255
+    result = result.clip(0, 255).to(torch.uint8)
+    mask = mask.clip(0, 255).to(torch.uint8)
+
+    return result
+
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
+                    crop=True, value_range=(-1, 1), border_ratio: float = None, recenter_size: int = 512):
+
+    if border_ratio is not None:
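+        # recenter() works on HWC uint8 images, one at a time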
+        image = (image * 255).clamp(0, 255).to(torch.uint8)
+        image = [recenter(img, border_ratio=border_ratio) for img in image]
+
+        image = torch.stack(image, dim=0)
+        image = resize_cubic(image, size=(recenter_size, recenter_size))
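+        # back to float: map [0, 255] to [-1, 1], then rescale value_range to [0, 1]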
+        image = image / 255 * 2 - 1
+        low, high = value_range
+
+        image = (image - low) / (high - low)
+        image = image.permute(0, 2, 3, 1)
+
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
+
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
+
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
         if crop:
@@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
         h = (image.shape[2] - size) // 2
         w = (image.shape[3] - size) // 2
         image = image[:, :, h:h + size, w:w + size]
@@ -71,9 +288,9 @@ def load_sd(self, sd):
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image, crop=True):
+    def encode_image(self, image, crop=True, border_ratio: float = None):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
@@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
-    elif "embeddings.patch_embeddings.projection.weight" in sd:
+
+    # Dinov2
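+    # distinguish giant (40 transformer layers) from large (24) by probing the last layer's key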
+    elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+    elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
     else:
         return None
 