|
@@ -4,11 +4,12 @@
 from PIL import Image
 
 from invokeai.app.invocations.fields import FluxKontextConditioningField
-from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling_utils import pack
 from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
+from invokeai.backend.util.devices import TorchDevice
 
 
 def generate_img_ids_with_offset(
@@ -149,7 +150,13 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]: |
         image_tensor = image_tensor.to(self._device)
 
         # Continue with VAE encoding
-        kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        # Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
+        with vae_info as vae:
+            assert isinstance(vae, AutoEncoder)
+            vae_dtype = next(iter(vae.parameters())).dtype
+            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
+            # Use sample=False to get the distribution mean without noise
+            kontext_latents_unpacked = vae.encode(image_tensor, sample=False)
 
         # Extract tensor dimensions
         batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
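
The change above makes the reference-image encode deterministic: rather than drawing a random sample from the VAE's latent distribution, it takes the distribution mean via sample=False. A minimal sketch of how a diagonal-Gaussian VAE encode typically honours such a flag (illustrative only; the function name and tensor layout here are assumptions, not InvokeAI's actual AutoEncoder internals):

import torch

def diagonal_gaussian_encode(moments: torch.Tensor, sample: bool = True) -> torch.Tensor:
    # `moments` is assumed to hold the encoder's mean and log-variance
    # stacked along the channel axis, the usual diagonal-Gaussian layout.
    mean, logvar = torch.chunk(moments, 2, dim=1)
    if sample:
        # Stochastic path: reparameterized draw from N(mean, sigma^2).
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
    # Deterministic path (sample=False): return the mean, so repeated
    # encodes of the same reference image yield identical latents.
    return mean

Encoding to the mean avoids injecting fresh noise into the conditioning latents on every run, which keeps repeated generations with the same reference image consistent and, per the in-code comment, matches ComfyUI's behaviour.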
|