@@ -1227,22 +1227,23 @@ def extra_conds(self, **kwargs):
         if audio_embed is not None:
             out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
 
-        if "c_concat" not in out or "concat_latent_image" in kwargs: # 1.7B model OR I2V mode
-            reference_latents = kwargs.get("reference_latents", None)
-            if reference_latents is not None:
-                out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+        reference_latents = kwargs.get("reference_latents", None)
+
+        if "c_concat" not in out and reference_latents is not None and reference_latents[0].shape[1] == 16: # 1.7B model
+            out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
         else:
-            noise_shape = list(noise.shape)
-            noise_shape[1] += 4
-            concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
-            zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
-            zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
-            zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
-            concat_latent[:, 4:] = zero_vae_values
-            concat_latent[:, 4:, :1] = zero_vae_values_first
-            concat_latent[:, 4:, 1:2] = zero_vae_values_second
-            out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
-            reference_latents = kwargs.get("reference_latents", None)
+            concat_latent_image = kwargs.get("concat_latent_image", None)
+            if concat_latent_image is None:
+                noise_shape = list(noise.shape)
+                noise_shape[1] += 4
+                concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
+                zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
+                zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
+                zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
+                concat_latent[:, 4:] = zero_vae_values
+                concat_latent[:, 4:, :1] = zero_vae_values_first
+                concat_latent[:, 4:, 1:2] = zero_vae_values_second
+                out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
             if reference_latents is not None:
                 ref_latent = self.process_latent_in(reference_latents[-1])
                 ref_latent_shape = list(ref_latent.shape)
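For context on the broadcast-fill pattern used in the new `if concat_latent_image is None:` branch: the concat latent gets 4 extra mask channels (left at zero), the remaining 16 channels are filled with precomputed "zero frame" VAE constants, and the first two temporal frames are then overwritten with their own dedicated constants. The sketch below reproduces only that slicing/broadcasting behavior in isolation; it is not the real `extra_conds` path, and the noise shape and placeholder constant values are assumptions chosen so the snippet runs on its own.

```python
import torch

# Assumed toy shape: (batch, 16 latent channels, 5 frames, 8x8 spatial).
# In the diff this comes from the sampler's noise tensor.
noise = torch.zeros(1, 16, 5, 8, 8)

noise_shape = list(noise.shape)
noise_shape[1] += 4  # prepend 4 mask channels -> 20 channels total
concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)

# Per-channel constants shaped (1, 16, 1, 1, 1) so they broadcast over batch,
# frames and spatial dims. Placeholder values here; the diff lists the real 16.
zero_vae_values_first = torch.full((1, 16, 1, 1, 1), 0.5)
zero_vae_values_second = torch.full((1, 16, 1, 1, 1), 1.0)
zero_vae_values = torch.full((1, 16, 1, 1, 1), 2.0)

concat_latent[:, 4:] = zero_vae_values              # every frame: generic zero-latent values
concat_latent[:, 4:, :1] = zero_vae_values_first    # frame 0 gets its own constants
concat_latent[:, 4:, 1:2] = zero_vae_values_second  # frame 1 gets its own constants

print(concat_latent.shape)            # torch.Size([1, 20, 5, 8, 8])
print(concat_latent[0, 4, :3, 0, 0])  # tensor([0.5000, 1.0000, 2.0000]) -> frames 0, 1, 2+
```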