Commit e522de3

refactor(nodes): roll back latent-space resizing of kontext images
Parent: d591b50 · Commit: e522de3

1 file changed

invokeai/backend/flux/extensions/kontext_extension.py

Lines changed: 11 additions & 22 deletions
```diff
@@ -1,14 +1,13 @@
 import torch
+import torch.nn.functional as F
 import torchvision.transforms as T
 from einops import repeat
-from PIL import Image

 from invokeai.app.invocations.fields import FluxKontextConditioningField
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling_utils import pack
-from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
 from invokeai.backend.util.devices import TorchDevice
```
```diff
@@ -115,29 +114,10 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
         for idx, kontext_field in enumerate(self.kontext_conditioning):
             image = self._context.images.get_pil(kontext_field.image.image_name)

-            # Calculate aspect ratio of input image
-            width, height = image.size
-            aspect_ratio = width / height
-
-            # Find the closest preferred resolution by aspect ratio
-            _, target_width, target_height = min(
-                ((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
-            )
-
-            # Apply BFL's scaling formula
-            # This ensures compatibility with the model's training
-            scaled_width = 2 * int(target_width / 16)
-            scaled_height = 2 * int(target_height / 16)
-
-            # Resize to the exact resolution used during training
+            # Convert to RGB
             image = image.convert("RGB")
-            final_width = 8 * scaled_width
-            final_height = 8 * scaled_height
-            # Use BICUBIC for smoother resizing to reduce artifacts
-            image = image.resize((final_width, final_height), Image.Resampling.BICUBIC)

             # Convert to tensor using torchvision transforms for consistency
-            # This matches the normalization used in image_resized_to_grid_as_tensor
             transformation = T.Compose(
                 [
                     T.ToTensor(),  # Converts PIL image to tensor and scales to [0, 1]
```
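
For context, the rolled-back code snapped each kontext image to the preferred BFL resolution with the closest aspect ratio, then applied BFL's scaling formula, which amounts to flooring each side to a multiple of 16 (8 * 2 * int(t / 16) = 16 * int(t / 16)). A minimal standalone sketch of that selection step; the resolution table here is an illustrative stand-in for PREFERED_KONTEXT_RESOLUTIONS in invokeai.backend.flux.util:

```python
# Sketch of the rolled-back resolution snapping (illustrative values only).
PREFERRED_RESOLUTIONS = [(1024, 1024), (1216, 832), (832, 1216)]  # stand-in table

def snap_to_preferred(width: int, height: int) -> tuple[int, int]:
    aspect_ratio = width / height
    # Choose the preferred resolution whose aspect ratio is closest to the input's.
    _, target_w, target_h = min(
        ((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_RESOLUTIONS),
        key=lambda x: x[0],
    )
    # BFL's scaling formula: 8 * 2 * (t // 16) floors each side to a multiple of 16.
    return 8 * (2 * (target_w // 16)), 8 * (2 * (target_h // 16))

# A 1920x1080 input (aspect ratio ~1.78) lands in the 1216x832 bucket; both
# sides are already multiples of 16, so the formula leaves them unchanged.
print(snap_to_preferred(1920, 1080))  # (1216, 832)
```

After this commit the image is only converted to RGB; no pixel-space resize happens in this code path, and the patch-size divisibility requirement is instead enforced on the latents, as the next hunk shows.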
```diff
@@ -161,6 +141,15 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
         # Extract tensor dimensions
         batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape

+        # Pad latents to be compatible with patch_size=2
+        # This ensures dimensions are even for the pack() function
+        pad_h = (2 - latent_height % 2) % 2
+        pad_w = (2 - latent_width % 2) % 2
+        if pad_h > 0 or pad_w > 0:
+            kontext_latents_unpacked = F.pad(kontext_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
+            # Update dimensions after padding
+            _, _, latent_height, latent_width = kontext_latents_unpacked.shape
+
         # Pack the latents
         kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
```
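The evenness requirement comes from pack(), which folds the latent grid into 2x2 patches and therefore needs both spatial dimensions divisible by 2. A minimal sketch of the new padding step, assuming pack() is the einops rearrange used by FLUX's sampling utilities; the latent shape is invented for illustration:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def pack(x: torch.Tensor) -> torch.Tensor:
    # 2x2 patchify, as in FLUX's sampling_utils.pack; raises on odd spatial dims.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

latents = torch.randn(1, 16, 107, 128)  # odd height: pack() would fail here

# Pad each spatial dim up to the next even size; pad_h and pad_w are each 0 or 1.
# Circular mode wraps content from the opposite edge instead of inserting
# zeros the VAE never produced.
pad_h = (2 - latents.shape[-2] % 2) % 2
pad_w = (2 - latents.shape[-1] % 2) % 2
if pad_h or pad_w:
    latents = F.pad(latents, (0, pad_w, 0, pad_h), mode="circular")

print(pack(latents).shape)  # torch.Size([1, 3456, 64]): (108/2)*(128/2) patches of 16*2*2
```

Since at most one row and one column are ever added, the wrapped border is a small perturbation; presumably circular padding was preferred over zeros so the padded strip keeps plausible latent statistics.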