Commit c9ebe70

Some changes to the previous hunyuan PR. (#9725)
1 parent 261421e commit c9ebe70

4 files changed (+14, -263 lines)

comfy/clip_vision.py

Lines changed: 4 additions & 221 deletions
@@ -17,227 +17,10 @@ def __getitem__(self, key):
     def __setitem__(self, key, item):
         setattr(self, key, item)

-
-def cubic_kernel(x, a: float = -0.75):
-    absx = x.abs()
-    absx2 = absx ** 2
-    absx3 = absx ** 3
-
-    w = (a + 2) * absx3 - (a + 3) * absx2 + 1
-    w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a
-
-    return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x)))
-
-def get_indices_weights(in_size, out_size, scale):
-    # OpenCV-style half-pixel mapping
-    x = torch.arange(out_size, dtype=torch.float32)
-    x = (x + 0.5) / scale - 0.5
-
-    x0 = x.floor().long()
-    dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3))
-
-    weights = cubic_kernel(dx)
-    weights = weights / weights.sum(dim=1, keepdim=True)
-
-    indices = x0.unsqueeze(1) + torch.arange(-1, 3)
-    indices = indices.clamp(0, in_size - 1)
-
-    return indices, weights
-
-def resize_cubic_1d(x, out_size, dim):
-    b, c, h, w = x.shape
-    in_size = h if dim == 2 else w
-    scale = out_size / in_size
-
-    indices, weights = get_indices_weights(in_size, out_size, scale)
-
-    if dim == 2:
-        x = x.permute(0, 1, 3, 2)
-        x = x.reshape(-1, h)
-    else:
-        x = x.reshape(-1, w)
-
-    gathered = x[:, indices]
-    out = (gathered * weights.unsqueeze(0)).sum(dim=2)
-
-    if dim == 2:
-        out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
-    else:
-        out = out.reshape(b, c, h, out_size)
-
-    return out
-
-def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
-    """
-    Resize image using OpenCV-equivalent INTER_CUBIC interpolation.
-    Implemented in pure PyTorch
-    """
-
-    if img.ndim == 3:
-        img = img.unsqueeze(0)
-
-    img = img.permute(0, 3, 1, 2)
-
-    out_h, out_w = size
-    img = resize_cubic_1d(img, out_h, dim=2)
-    img = resize_cubic_1d(img, out_w, dim=3)
-    return img
-
-def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
-    # vectorized implementation for OpenCV's INTER_AREA using pure PyTorch
-    original_shape = img.shape
-    is_hwc = False
-
-    if img.ndim == 3:
-        if img.shape[0] <= 4:
-            img = img.unsqueeze(0)
-        else:
-            is_hwc = True
-            img = img.permute(2, 0, 1).unsqueeze(0)
-    elif img.ndim == 4:
-        pass
-    else:
-        raise ValueError("Expected image with 3 or 4 dims.")
-
-    B, C, H, W = img.shape
-    out_h, out_w = size
-    scale_y = H / out_h
-    scale_x = W / out_w
-
-    device = img.device
-
-    # compute the grid boundries
-    y_start = torch.arange(out_h, device=device).float() * scale_y
-    y_end = y_start + scale_y
-    x_start = torch.arange(out_w, device=device).float() * scale_x
-    x_end = x_start + scale_x
-
-    # for each output pixel, we will compute the range for it
-    y_start_int = torch.floor(y_start).long()
-    y_end_int = torch.ceil(y_end).long()
-    x_start_int = torch.floor(x_start).long()
-    x_end_int = torch.ceil(x_end).long()
-
-    # We will build the weighted sums by iterating over contributing input pixels once
-    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
-    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)
-
-    max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
-    max_kernel_w = int(torch.max(x_end_int - x_start_int).item())
-
-    for dy in range(max_kernel_h):
-        for dx in range(max_kernel_w):
-            # compute the weights for this offset for all output pixels
-
-            y_idx = y_start_int.unsqueeze(1) + dy
-            x_idx = x_start_int.unsqueeze(0) + dx
-
-            # clamp indices to image boundaries
-            y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
-            x_idx_clamped = torch.clamp(x_idx, 0, W - 1)
-
-            # compute weights by broadcasting
-            y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
-            x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)
-
-            weight = (y_weight * x_weight)
-
-            y_expand = y_idx_clamped.expand(out_h, out_w)
-            x_expand = x_idx_clamped.expand(out_h, out_w)
-
-
-            pixels = img[:, :, y_expand, x_expand]
-
-            # unsqueeze to broadcast
-            w = weight.unsqueeze(0).unsqueeze(0)
-
-            output += pixels * w
-            area += weight
-
-    # Normalize by area
-    output /= area.unsqueeze(0).unsqueeze(0)
-
-    if is_hwc:
-        return output[0].permute(1, 2, 0)
-    elif img.shape[0] == 1 and original_shape[0] <= 4:
-        return output[0]
-    else:
-        return output
-
-def recenter(image, border_ratio: float = 0.2):
-
-    if image.shape[-1] == 4:
-        mask = image[..., 3]
-    else:
-        mask = torch.ones_like(image[..., 0:1]) * 255
-        image = torch.concatenate([image, mask], axis=-1)
-        mask = mask[..., 0]
-
-    H, W, C = image.shape
-
-    size = max(H, W)
-    result = torch.zeros((size, size, C), dtype = torch.uint8)
-
-    # as_tuple to match numpy behaviour
-    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)
-
-    y_min, y_max = y_coords.min(), y_coords.max()
-    x_min, x_max = x_coords.min(), x_coords.max()
-
-    h = x_max - x_min
-    w = y_max - y_min
-
-    if h == 0 or w == 0:
-        raise ValueError('input image is empty')
-
-    desired_size = int(size * (1 - border_ratio))
-    scale = desired_size / max(h, w)
-
-    h2 = int(h * scale)
-    w2 = int(w * scale)
-
-    x2_min = (size - h2) // 2
-    x2_max = x2_min + h2
-
-    y2_min = (size - w2) // 2
-    y2_max = y2_min + w2
-
-    # note: opencv takes columns first (opposite to pytorch and numpy that take the row first)
-    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))
-
-    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype = torch.uint8) * 255
-
-    mask = result[..., 3:].to(torch.float32) / 255
-    result = result[..., :3] * mask + bg * (1 - mask)
-
-    mask = mask * 255
-    result = result.clip(0, 255).to(torch.uint8)
-    mask = mask.clip(0, 255).to(torch.uint8)
-
-    return result
-
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
-                    crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512):
-
-    if border_ratio is not None:
-
-        image = (image * 255).clamp(0, 255).to(torch.uint8)
-        image = [recenter(img, border_ratio = border_ratio) for img in image]

-        image = torch.stack(image, dim = 0)
-        image = resize_cubic(image, size = (recenter_size, recenter_size))
-
-        image = image / 255 * 2 - 1
-        low, high = value_range
-
-        image = (image - low) / (high - low)
-        image = image.permute(0, 2, 3, 1)
-
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
-
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
-
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
         if crop:
@@ -246,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)

-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
@@ -288,9 +71,9 @@ def load_sd(self, sd):
     def get_sd(self):
        return self.model.state_dict()

-    def encode_image(self, image, crop=True, border_ratio: float = None):
+    def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

         outputs = Output()
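
The net effect in this file is that clip_preprocess no longer carries its own OpenCV-style resizing: the pure-PyTorch INTER_CUBIC/INTER_AREA helpers and the recenter()/border_ratio path are gone, and every resize goes through torch.nn.functional.interpolate with mode="bicubic". The following standalone sketch (not part of the commit; the input size and 2x downscale are made up for illustration) shows how the built-in torch modes relate to the OpenCV interpolations the deleted helpers emulated:

import cv2
import numpy as np
import torch
import torch.nn.functional as F

img = (np.random.rand(448, 448, 3) * 255).astype(np.float32)   # HWC float image
t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)        # NCHW for torch

# Bicubic: OpenCV and PyTorch both use a cubic kernel with a = -0.75 (the same
# default the deleted cubic_kernel hard-coded), so results should be close but
# not bit-identical, since border handling differs.  cv2.resize takes (width, height).
ref_cubic = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
out_cubic = F.interpolate(t, size=(224, 224), mode="bicubic", align_corners=False)
print(np.abs(ref_cubic - out_cubic[0].permute(1, 2, 0).numpy()).max())

# Area averaging: torch's "area" mode is an adaptive average pool and agrees with
# OpenCV's INTER_AREA for integer downscale factors such as the 2x used here; for
# fractional factors the deleted resize_area computed per-pixel area weights instead.
ref_area = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
out_area = F.interpolate(t, size=(224, 224), mode="area")
print(np.abs(ref_area - out_area[0].permute(1, 2, 0).numpy()).max())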

comfy/sd.py

Lines changed: 0 additions & 21 deletions
@@ -1058,27 +1058,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     model = None
     model_patcher = None

-    if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]):
-        from collections import OrderedDict
-        import gc
-
-        merged_sd = OrderedDict()
-
-        for k, v in sd["model"].items():
-            merged_sd[f"model.{k}"] = v
-
-        for k, v in sd["vae"].items():
-            merged_sd[f"vae.{k}"] = v
-
-        for key, value in sd["conditioner"].items():
-            merged_sd[f"conditioner.{key}"] = value
-
-        sd = merged_sd
-
-        del merged_sd
-        gc.collect()
-        torch.cuda.empty_cache()
-
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
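
The deleted block special-cased checkpoints stored as a nested dict with "model", "vae" and "conditioner" sections and flattened them into one prefixed state dict before detection. For reference, the general pattern it implemented looks like the following generic sketch (illustrative names, not ComfyUI API):

def flatten_checkpoint(nested_sd):
    # {"model": {...}, "vae": {...}, "conditioner": {...}} -> one flat dict whose
    # keys carry their section name as a prefix (e.g. a hypothetical "model.w").
    flat = {}
    for section, weights in nested_sd.items():
        for key, tensor in weights.items():
            flat[f"{section}.{key}"] = tensor
    return flat

With the block removed, load_state_dict_guess_config again assumes sd is already a flat state dict before passing it to unet_prefix_from_state_dict.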

nodes.py

Lines changed: 9 additions & 20 deletions
@@ -998,31 +998,20 @@ def load_clip(self, clip_name):
 class CLIPVisionEncode:
     @classmethod
     def INPUT_TYPES(s):
-        return {
-            "required": {
-                "clip_vision": ("CLIP_VISION",),
-                "image": ("IMAGE",),
-                "crop": (["center", "none", "recenter"],),
-            },
-            "optional": {
-                "border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}),
-            }
-        }
-
+        return {"required": { "clip_vision": ("CLIP_VISION",),
+                              "image": ("IMAGE",),
+                              "crop": (["center", "none"],)
+                             }}
     RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
     FUNCTION = "encode"

     CATEGORY = "conditioning"

-    def encode(self, clip_vision, image, crop, border_ratio):
-        crop_image = crop == "center"
-
-        if crop == "recenter":
-            crop_image = True
-        else:
-            border_ratio = None
-
-        output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio)
+    def encode(self, clip_vision, image, crop):
+        crop_image = True
+        if crop != "center":
+            crop_image = False
+        output = clip_vision.encode_image(image, crop=crop_image)
         return (output,)

 class StyleModelLoader:
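
With the "recenter" option and the border_ratio input reverted, the node reduces to a thin wrapper: the crop dropdown only decides whether clip_preprocess center-crops. A hypothetical standalone equivalent (names assumed, not part of the commit):

def encode_clip_vision(clip_vision, image, crop="center"):
    # "center" -> crop=True, anything else ("none") -> crop=False
    return clip_vision.encode_image(image, crop=(crop == "center"))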

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -27,4 +27,4 @@ kornia>=0.7.1
 spandrel
 soundfile
 pydantic~=2.0
-pydantic-settings~=2.0
+pydantic-settings~=2.0
