@@ -185,6 +185,11 @@ def concat_cond(self, **kwargs):
185185
186186 if concat_latent_image .shape [1 :] != noise .shape [1 :]:
187187 concat_latent_image = utils .common_upscale (concat_latent_image , noise .shape [- 1 ], noise .shape [- 2 ], "bilinear" , "center" )
188+ if noise .ndim == 5 :
189+ if concat_latent_image .shape [- 3 ] < noise .shape [- 3 ]:
190+ concat_latent_image = torch .nn .functional .pad (concat_latent_image , (0 , 0 , 0 , 0 , 0 , noise .shape [- 3 ] - concat_latent_image .shape [- 3 ]), "constant" , 0 )
191+ else :
192+ concat_latent_image = concat_latent_image [:, :, :noise .shape [- 3 ]]
188193
189194 concat_latent_image = utils .resize_to_batch_size (concat_latent_image , noise .shape [0 ])
190195
@@ -213,6 +218,11 @@ def concat_cond(self, **kwargs):
213218 cond_concat .append (self .blank_inpaint_image_like (noise ))
214219 elif ck == "mask_inverted" :
215220 cond_concat .append (torch .zeros_like (noise )[:, :1 ])
221+ if ck == "concat_image" :
222+ if concat_latent_image is not None :
223+ cond_concat .append (concat_latent_image .to (device ))
224+ else :
225+ cond_concat .append (torch .zeros_like (noise ))
216226 data = torch .cat (cond_concat , dim = 1 )
217227 return data
218228 return None
@@ -872,20 +882,17 @@ def extra_conds(self, **kwargs):
872882 if cross_attn is not None :
873883 out ['c_crossattn' ] = comfy .conds .CONDRegular (cross_attn )
874884
875- image = kwargs .get ("concat_latent_image" , None )
876- noise = kwargs .get ("noise" , None )
877-
878- if image is not None :
879- padding_shape = (noise .shape [0 ], 16 , noise .shape [2 ] - 1 , noise .shape [3 ], noise .shape [4 ])
880- latent_padding = torch .zeros (padding_shape , device = noise .device , dtype = noise .dtype )
881- image_latents = torch .cat ([image .to (noise ), latent_padding ], dim = 2 )
882- out ['c_concat' ] = comfy .conds .CONDNoiseShape (self .process_latent_in (image_latents ))
883-
884885 guidance = kwargs .get ("guidance" , 6.0 )
885886 if guidance is not None :
886887 out ['guidance' ] = comfy .conds .CONDRegular (torch .FloatTensor ([guidance ]))
887888 return out
888889
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
    """HunyuanVideo variant whose only concat key is ``"concat_image"``.

    The constructor forwards all of its arguments to ``HunyuanVideo`` and
    then sets ``self.concat_keys`` to the one-element tuple
    ``("concat_image",)``.

    NOTE(review): presumably ``concat_cond`` iterates ``self.concat_keys``
    and, for ``"concat_image"``, appends the concat latent image (or zeros
    shaped like the noise when none is given) -- confirm against the base
    model implementation.
    """

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        """Initialise the parent model, then select the concat key.

        Args:
            model_config: Passed through unchanged to ``HunyuanVideo``.
            model_type: Passed through unchanged; defaults to
                ``ModelType.FLOW``.
            device: Passed through unchanged as the ``device`` keyword.
        """
        super().__init__(model_config, model_type, device=device)
        # Assigned after the parent constructor so that this is the value
        # left on the instance regardless of what the parent sets.
        self.concat_keys = ("concat_image",)
894+
895+
889896class CosmosVideo (BaseModel ):
890897 def __init__ (self , model_config , model_type = ModelType .EDM , image_to_video = False , device = None ):
891898 super ().__init__ (model_config , model_type , device = device , unet_model = comfy .ldm .cosmos .model .GeneralDIT )
0 commit comments