Commit 261421e

Add Hunyuan 3D 2.1 Support (#8714)
1 parent a9f1bb1 commit 261421e

File tree: 13 files changed (+1536 / -128 lines)

comfy/clip_vision.py

Lines changed: 226 additions & 5 deletions
@@ -17,10 +17,227 @@ def __getitem__(self, key):
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+
+def cubic_kernel(x, a: float = -0.75):
+    absx = x.abs()
+    absx2 = absx ** 2
+    absx3 = absx ** 3
+
+    w = (a + 2) * absx3 - (a + 3) * absx2 + 1
+    w2 = a * absx3 - 5*a * absx2 + 8*a * absx - 4*a
+
+    return torch.where(absx <= 1, w, torch.where(absx < 2, w2, torch.zeros_like(x)))
+
+def get_indices_weights(in_size, out_size, scale):
+    # OpenCV-style half-pixel mapping
+    x = torch.arange(out_size, dtype=torch.float32)
+    x = (x + 0.5) / scale - 0.5
+
+    x0 = x.floor().long()
+    dx = x.unsqueeze(1) - (x0.unsqueeze(1) + torch.arange(-1, 3))
+
+    weights = cubic_kernel(dx)
+    weights = weights / weights.sum(dim=1, keepdim=True)
+
+    indices = x0.unsqueeze(1) + torch.arange(-1, 3)
+    indices = indices.clamp(0, in_size - 1)
+
+    return indices, weights
+
+def resize_cubic_1d(x, out_size, dim):
+    b, c, h, w = x.shape
+    in_size = h if dim == 2 else w
+    scale = out_size / in_size
+
+    indices, weights = get_indices_weights(in_size, out_size, scale)
+
+    if dim == 2:
+        x = x.permute(0, 1, 3, 2)
+        x = x.reshape(-1, h)
+    else:
+        x = x.reshape(-1, w)
+
+    gathered = x[:, indices]
+    out = (gathered * weights.unsqueeze(0)).sum(dim=2)
+
+    if dim == 2:
+        out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
+    else:
+        out = out.reshape(b, c, h, out_size)
+
+    return out
+
+def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
+    """
+    Resize image using OpenCV-equivalent INTER_CUBIC interpolation.
+    Implemented in pure PyTorch.
+    """
+
+    if img.ndim == 3:
+        img = img.unsqueeze(0)
+
+    img = img.permute(0, 3, 1, 2)
+
+    out_h, out_w = size
+    img = resize_cubic_1d(img, out_h, dim=2)
+    img = resize_cubic_1d(img, out_w, dim=3)
+    return img
+
+def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
+    # vectorized implementation of OpenCV's INTER_AREA in pure PyTorch
+    original_shape = img.shape
+    is_hwc = False
+
+    if img.ndim == 3:
+        if img.shape[0] <= 4:
+            img = img.unsqueeze(0)
+        else:
+            is_hwc = True
+            img = img.permute(2, 0, 1).unsqueeze(0)
+    elif img.ndim == 4:
+        pass
+    else:
+        raise ValueError("Expected image with 3 or 4 dims.")
+
+    B, C, H, W = img.shape
+    out_h, out_w = size
+    scale_y = H / out_h
+    scale_x = W / out_w
+
+    device = img.device
+
+    # compute the grid boundaries
+    y_start = torch.arange(out_h, device=device).float() * scale_y
+    y_end = y_start + scale_y
+    x_start = torch.arange(out_w, device=device).float() * scale_x
+    x_end = x_start + scale_x
+
+    # for each output pixel, compute the range of contributing input pixels
+    y_start_int = torch.floor(y_start).long()
+    y_end_int = torch.ceil(y_end).long()
+    x_start_int = torch.floor(x_start).long()
+    x_end_int = torch.ceil(x_end).long()
+
+    # build the weighted sums by iterating over contributing input pixels once
+    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
+    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)
+
+    max_kernel_h = int(torch.max(y_end_int - y_start_int).item())
+    max_kernel_w = int(torch.max(x_end_int - x_start_int).item())
+
+    for dy in range(max_kernel_h):
+        for dx in range(max_kernel_w):
+            # compute the weights for this offset for all output pixels
+
+            y_idx = y_start_int.unsqueeze(1) + dy
+            x_idx = x_start_int.unsqueeze(0) + dx
+
+            # clamp indices to image boundaries
+            y_idx_clamped = torch.clamp(y_idx, 0, H - 1)
+            x_idx_clamped = torch.clamp(x_idx, 0, W - 1)
+
+            # compute weights by broadcasting
+            y_weight = (torch.min(y_end.unsqueeze(1), y_idx_clamped.float() + 1.0) - torch.max(y_start.unsqueeze(1), y_idx_clamped.float())).clamp(min=0)
+            x_weight = (torch.min(x_end.unsqueeze(0), x_idx_clamped.float() + 1.0) - torch.max(x_start.unsqueeze(0), x_idx_clamped.float())).clamp(min=0)
+
+            weight = (y_weight * x_weight)
+
+            y_expand = y_idx_clamped.expand(out_h, out_w)
+            x_expand = x_idx_clamped.expand(out_h, out_w)
+
+            pixels = img[:, :, y_expand, x_expand]
+
+            # unsqueeze to broadcast
+            w = weight.unsqueeze(0).unsqueeze(0)
+
+            output += pixels * w
+            area += weight
+
+    # normalize by area
+    output /= area.unsqueeze(0).unsqueeze(0)
+
+    if is_hwc:
+        return output[0].permute(1, 2, 0)
+    elif img.shape[0] == 1 and original_shape[0] <= 4:
+        return output[0]
+    else:
+        return output
+
+def recenter(image, border_ratio: float = 0.2):
+
+    if image.shape[-1] == 4:
+        mask = image[..., 3]
+    else:
+        mask = torch.ones_like(image[..., 0:1]) * 255
+        image = torch.concatenate([image, mask], axis=-1)
+        mask = mask[..., 0]
+
+    H, W, C = image.shape
+
+    size = max(H, W)
+    result = torch.zeros((size, size, C), dtype=torch.uint8)
+
+    # as_tuple to match numpy behaviour
+    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)
+
+    y_min, y_max = y_coords.min(), y_coords.max()
+    x_min, x_max = x_coords.min(), x_coords.max()
+
+    h = x_max - x_min
+    w = y_max - y_min
+
+    if h == 0 or w == 0:
+        raise ValueError('input image is empty')
+
+    desired_size = int(size * (1 - border_ratio))
+    scale = desired_size / max(h, w)
+
+    h2 = int(h * scale)
+    w2 = int(w * scale)
+
+    x2_min = (size - h2) // 2
+    x2_max = x2_min + h2
+
+    y2_min = (size - w2) // 2
+    y2_max = y2_min + w2
+
+    # note: opencv takes columns first (opposite to pytorch and numpy, which take the row first)
+    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))
+
+    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype=torch.uint8) * 255
+
+    mask = result[..., 3:].to(torch.float32) / 255
+    result = result[..., :3] * mask + bg * (1 - mask)
+
+    mask = mask * 255
+    result = result.clip(0, 255).to(torch.uint8)
+    mask = mask.clip(0, 255).to(torch.uint8)
+
+    return result
+
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
+                    crop=True, value_range=(-1, 1), border_ratio: float = None, recenter_size: int = 512):
+
+    if border_ratio is not None:
+
+        image = (image * 255).clamp(0, 255).to(torch.uint8)
+        image = [recenter(img, border_ratio=border_ratio) for img in image]
+
+        image = torch.stack(image, dim=0)
+        image = resize_cubic(image, size=(recenter_size, recenter_size))
+
+        image = image / 255 * 2 - 1
+        low, high = value_range
+
+        image = (image - low) / (high - low)
+        image = image.permute(0, 2, 3, 1)
+
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
+
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
+
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
         if crop:
@@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
@@ -71,9 +288,9 @@ def load_sd(self, sd):
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image, crop=True):
+    def encode_image(self, image, crop=True, border_ratio: float = None):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
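
The hunk above threads the new border_ratio argument from ClipVisionModel.encode_image into clip_preprocess. A minimal call-site sketch, not part of the commit; the checkpoint path is a placeholder, and it assumes the Output object returned by encode_image still exposes last_hidden_state:

import torch
import comfy.clip_vision

clip_vision = comfy.clip_vision.load("clip_vision_model.safetensors")  # placeholder path
image = torch.rand(1, 512, 512, 3)  # NHWC image batch in [0, 1]
# border_ratio=0.15 routes preprocessing through recenter() + resize_cubic();
# border_ratio=None (the default) keeps the previous behaviour.
out = clip_vision.encode_image(image, crop=True, border_ratio=0.15)
print(out.last_hidden_state.shape)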
@@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
-    elif "embeddings.patch_embeddings.projection.weight" in sd:
+
+    # Dinov2
+    elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+    elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
     else:
         return None
 
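A quick shape check for the pure-PyTorch resize helpers added at the top of this file; a sketch under the assumption that resize_cubic and resize_area are importable exactly as defined above (resize_cubic returns a batched NCHW tensor, resize_area keeps the HWC layout it is given):

import torch
from comfy.clip_vision import resize_cubic, resize_area

img = torch.rand(240, 320, 3)        # HWC float image
up = resize_cubic(img, (480, 640))   # INTER_CUBIC-style resize, returns NCHW
down = resize_area(img, (120, 160))  # INTER_AREA-style downscale, stays HWC
print(up.shape)    # expected: torch.Size([1, 3, 480, 640])
print(down.shape)  # expected: torch.Size([120, 160, 3])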
comfy/image_encoders/dino2.py

Lines changed: 26 additions & 7 deletions
@@ -31,6 +31,20 @@ def __init__(self, dim, dtype, device, operations):
     def forward(self, x):
         return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
 
+class Dinov2MLP(torch.nn.Module):
+    def __init__(self, hidden_size: int, dtype, device, operations):
+        super().__init__()
+
+        mlp_ratio = 4
+        hidden_features = int(hidden_size * mlp_ratio)
+        self.fc1 = operations.Linear(hidden_size, hidden_features, bias=True, device=device, dtype=dtype)
+        self.fc2 = operations.Linear(hidden_features, hidden_size, bias=True, device=device, dtype=dtype)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = torch.nn.functional.gelu(hidden_state)
+        hidden_state = self.fc2(hidden_state)
+        return hidden_state
 
 class SwiGLUFFN(torch.nn.Module):
     def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ def forward(self, x):
 
 
 class Dino2Block(torch.nn.Module):
-    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
         super().__init__()
         self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
         self.layer_scale1 = LayerScale(dim, dtype, device, operations)
         self.layer_scale2 = LayerScale(dim, dtype, device, operations)
-        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        if use_swiglu_ffn:
+            self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        else:
+            self.mlp = Dinov2MLP(dim, dtype, device, operations)
         self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
         self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
 
@@ -66,9 +83,10 @@ def forward(self, x, optimized_attention):
 
 
 class Dino2Encoder(torch.nn.Module):
-    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
         super().__init__()
-        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn=use_swiglu_ffn)
+                                          for _ in range(num_layers)])
 
     def forward(self, x, intermediate_output=None):
         optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ def forward(self, x, intermediate_output=None):
             intermediate_output = len(self.layer) + intermediate_output
 
         intermediate = None
-        for i, l in enumerate(self.layer):
-            x = l(x, optimized_attention)
+        for i, layer in enumerate(self.layer):
+            x = layer(x, optimized_attention)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
@@ -128,9 +146,10 @@ def __init__(self, config_dict, dtype, device, operations):
         dim = config_dict["hidden_size"]
         heads = config_dict["num_attention_heads"]
         layer_norm_eps = config_dict["layer_norm_eps"]
+        use_swiglu_ffn = config_dict["use_swiglu_ffn"]
 
         self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
-        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn=use_swiglu_ffn)
         self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
 
     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
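
To illustrate what the new use_swiglu_ffn flag changes, here is a sketch, not part of the commit, that builds one block of each flavour; the _Ops namespace is a stand-in assumption for the operations module ComfyUI normally passes in:

import torch
from comfy.image_encoders.dino2 import Dino2Block

class _Ops:  # stand-in for comfy.ops; only Linear/LayerNorm are needed here
    Linear = torch.nn.Linear
    LayerNorm = torch.nn.LayerNorm

# dinov2-giant style block: SwiGLU feed-forward (previous behaviour)
giant_block = Dino2Block(1536, 24, 1e-6, torch.float32, "cpu", _Ops, use_swiglu_ffn=True)
# dinov2-large style block: plain GELU MLP added by this commit
large_block = Dino2Block(1024, 16, 1e-6, torch.float32, "cpu", _Ops, use_swiglu_ffn=False)
print(type(giant_block.mlp).__name__, type(large_block.mlp).__name__)  # SwiGLUFFN Dinov2MLP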
comfy/image_encoders/dino2_large.json

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+{
+  "hidden_size": 1024,
+  "use_mask_token": true,
+  "patch_size": 14,
+  "image_size": 518,
+  "num_channels": 3,
+  "num_attention_heads": 16,
+  "initializer_range": 0.02,
+  "attention_probs_dropout_prob": 0.0,
+  "hidden_dropout_prob": 0.0,
+  "hidden_act": "gelu",
+  "mlp_ratio": 4,
+  "model_type": "dinov2",
+  "num_hidden_layers": 24,
+  "layer_norm_eps": 1e-6,
+  "qkv_bias": true,
+  "use_swiglu_ffn": false,
+  "layerscale_value": 1.0,
+  "drop_path_rate": 0.0,
+  "image_mean": [0.485, 0.456, 0.406],
+  "image_std": [0.229, 0.224, 0.225]
+}
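The detection added in clip_vision.py keys off the last transformer block present in the state dict: with "num_hidden_layers": 24 as in the config above, the final block is encoder.layer.23, while the 40-layer giant variant ends at encoder.layer.39. A self-contained sketch of that idea (the helper name is illustrative, not part of the commit):

def guess_dino2_config(sd):
    # mirrors the elif chain added to load_clipvision_from_sd()
    if "encoder.layer.39.layer_scale2.lambda1" in sd:
        return "dino2_giant.json"
    if "encoder.layer.23.layer_scale2.lambda1" in sd:
        return "dino2_large.json"
    return None

fake_large_sd = {f"encoder.layer.{i}.layer_scale2.lambda1": None for i in range(24)}
print(guess_dino2_config(fake_large_sd))  # dino2_large.json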

comfy/latent_formats.py

Lines changed: 5 additions & 0 deletions
@@ -538,6 +538,11 @@ class Hunyuan3Dv2(LatentFormat):
     latent_dimensions = 1
     scale_factor = 0.9990943042622529
 
+class Hunyuan3Dv2_1(LatentFormat):
+    scale_factor = 1.0039506158752403
+    latent_channels = 64
+    latent_dimensions = 1
+
 class Hunyuan3Dv2mini(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
