diff --git a/comfy/model_base.py b/comfy/model_base.py
index b0b9cde7d087..40b7b375723f 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1227,22 +1227,23 @@ def extra_conds(self, **kwargs):
         if audio_embed is not None:
             out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
 
-        if "c_concat" not in out: # 1.7B model
-            reference_latents = kwargs.get("reference_latents", None)
-            if reference_latents is not None:
-                out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+        reference_latents = kwargs.get("reference_latents", None)
+
+        if "c_concat" not in out and reference_latents is not None and reference_latents[0].shape[1] == 16: # 1.7B model
+            out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
         else:
-            noise_shape = list(noise.shape)
-            noise_shape[1] += 4
-            concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
-            zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
-            zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
-            zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
-            concat_latent[:, 4:] = zero_vae_values
-            concat_latent[:, 4:, :1] = zero_vae_values_first
-            concat_latent[:, 4:, 1:2] = zero_vae_values_second
-            out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
-            reference_latents = kwargs.get("reference_latents", None)
+            concat_latent_image = kwargs.get("concat_latent_image", None)
+            if concat_latent_image is None:
+                noise_shape = list(noise.shape)
+                noise_shape[1] += 4
+                concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
+                zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
+                zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
+                zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
+                concat_latent[:, 4:] = zero_vae_values
+                concat_latent[:, 4:, :1] = zero_vae_values_first
+                concat_latent[:, 4:, 1:2] = zero_vae_values_second
+                out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
             if reference_latents is not None:
                 ref_latent = self.process_latent_in(reference_latents[-1])
                 ref_latent_shape = list(ref_latent.shape)
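
For context, here is a minimal standalone sketch (not part of the patch) of the tensor layout the new fallback branch builds when no concat_latent_image is supplied: four zero mask channels prepended to the 16 latent channels, which are then filled by broadcasting a per-channel constant. The helper name and the random fill value are illustrative only; the real code uses the hardcoded zero_vae_values tensors shown in the diff, with separate constants for the first two frames.

import torch

# Sketch of the fallback c_concat construction (names are hypothetical).
# noise: [batch, 16, frames, H, W] video latent, as in the diff.
def empty_c_concat(noise: torch.Tensor, fill: torch.Tensor) -> torch.Tensor:
    b, c, t, h, w = noise.shape
    concat = torch.zeros(b, c + 4, t, h, w,   # channels 0:4 are the mask, left at zero
                         device=noise.device, dtype=noise.dtype)
    concat[:, 4:] = fill                      # [1, 16, 1, 1, 1] broadcasts over frames and pixels
    return concat

noise = torch.randn(1, 16, 21, 60, 104)
fill = torch.randn(1, 16, 1, 1, 1)            # stand-in for the hardcoded zero-VAE constants
print(empty_c_concat(noise, fill).shape)      # torch.Size([1, 20, 21, 60, 104])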