Skip to content

Commit 7c7c70c

Browse files
Refactor skyreels i2v code.
1 parent 8362199 commit 7c7c70c

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

comfy/model_base.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def concat_cond(self, **kwargs):
185185

186186
if concat_latent_image.shape[1:] != noise.shape[1:]:
187187
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
188+
if noise.ndim == 5:
189+
if concat_latent_image.shape[-3] < noise.shape[-3]:
190+
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
191+
else:
192+
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
188193

189194
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
190195

@@ -213,6 +218,11 @@ def concat_cond(self, **kwargs):
213218
cond_concat.append(self.blank_inpaint_image_like(noise))
214219
elif ck == "mask_inverted":
215220
cond_concat.append(torch.zeros_like(noise)[:, :1])
221+
if ck == "concat_image":
222+
if concat_latent_image is not None:
223+
cond_concat.append(concat_latent_image.to(device))
224+
else:
225+
cond_concat.append(torch.zeros_like(noise))
216226
data = torch.cat(cond_concat, dim=1)
217227
return data
218228
return None
@@ -872,20 +882,17 @@ def extra_conds(self, **kwargs):
872882
if cross_attn is not None:
873883
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
874884

875-
image = kwargs.get("concat_latent_image", None)
876-
noise = kwargs.get("noise", None)
877-
878-
if image is not None:
879-
padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
880-
latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
881-
image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
882-
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
883-
884885
guidance = kwargs.get("guidance", 6.0)
885886
if guidance is not None:
886887
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
887888
return out
888889

890+
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
    """Skyreels image-to-video variant of the HunyuanVideo model.

    Behaves exactly like :class:`HunyuanVideo` except that
    ``self.concat_keys`` is set to ``("concat_image",)``, so the base
    class's concat-conditioning path appends the reference image latent
    (or a zero tensor when none is provided) to the conditioning.
    """

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        # Defer all real initialization to the parent model.
        super().__init__(model_config, model_type, device=device)
        # Only difference from the base model: request the image-latent
        # concat channel when building conditioning.
        self.concat_keys = ("concat_image",)
894+
895+
889896
class CosmosVideo(BaseModel):
890897
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
891898
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)

comfy/supported_models.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,16 @@ def clip_target(self, state_dict={}):
826826
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
827827
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
828828

829+
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
    """Supported-model config for the Skyreels HunyuanVideo I2V checkpoint.

    Distinguished from the base HunyuanVideo config by the doubled
    ``in_channels`` value (32) in the detected unet config.
    """

    unet_config = {
        "image_model": "hunyuan_video",
        "in_channels": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        # Instantiate the Skyreels-specific model wrapper for this config.
        return model_base.HunyuanVideoSkyreelsI2V(self, device=device)
838+
829839
class CosmosT2V(supported_models_base.BASE):
830840
unet_config = {
831841
"image_model": "cosmos",
@@ -939,6 +949,6 @@ def get_model(self, state_dict, prefix="", device=None):
939949
out = model_base.WAN21(self, image_to_video=True, device=device)
940950
return out
941951

942-
# Registry of supported model configs.
# NOTE(review): HunyuanVideoSkyreelsI2V is listed before HunyuanVideo —
# ordering presumably matters so the more specific config is matched
# first; confirm against the detection code before reordering.
models = [
    Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH,
    SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B,
    Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B,
    SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma,
    HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi,
    LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V,
    Lumina2, WAN21_T2V, WAN21_I2V,
]
943953

944954
models += [SVD_img2vid]

0 commit comments

Comments (0)