Skip to content

Commit acc152b

Browse files
authored
Support loading and using SkyReels-V1-Hunyuan-I2V (#6862)
* Support SkyReels-V1-Hunyuan-I2V * VAE scaling * Fix T2V oops * Proper latent scaling
1 parent b07258c commit acc152b

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

comfy/ldm/hunyuan_video/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def block_wrap(args):
310310
shape[i] = shape[i] // self.patch_size[i]
311311
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
312312
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
313-
img = img.reshape(initial_shape)
313+
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
314314
return img
315315

316316
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):

comfy/model_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,15 @@ def extra_conds(self, **kwargs):
871871
if cross_attn is not None:
872872
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
873873

874+
image = kwargs.get("concat_latent_image", None)
875+
noise = kwargs.get("noise", None)
876+
877+
if image is not None:
878+
padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
879+
latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
880+
image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
881+
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
882+
874883
guidance = kwargs.get("guidance", 6.0)
875884
if guidance is not None:
876885
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))

comfy/model_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix):
136136
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
137137
dit_config = {}
138138
dit_config["image_model"] = "hunyuan_video"
139-
dit_config["in_channels"] = 16
139+
dit_config["in_channels"] = state_dict["img_in.proj.weight"].shape[1] #SkyReels img2video has 32 input channels
140140
dit_config["patch_size"] = [1, 2, 2]
141141
dit_config["out_channels"] = 16
142142
dit_config["vec_in_dim"] = 768

0 commit comments

Comments
 (0)