 import torch
+import torch.nn.functional as F
 import torchvision.transforms as T
 from einops import repeat
-from PIL import Image
 
 from invokeai.app.invocations.fields import FluxKontextConditioningField
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling_utils import pack
-from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
 from invokeai.backend.util.devices import TorchDevice
 
 
@@ -115,29 +114,10 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
         for idx, kontext_field in enumerate(self.kontext_conditioning):
             image = self._context.images.get_pil(kontext_field.image.image_name)
 
-            # Calculate aspect ratio of input image
-            width, height = image.size
-            aspect_ratio = width / height
-
-            # Find the closest preferred resolution by aspect ratio
-            _, target_width, target_height = min(
-                ((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
-            )
-
-            # Apply BFL's scaling formula
-            # This ensures compatibility with the model's training
-            scaled_width = 2 * int(target_width / 16)
-            scaled_height = 2 * int(target_height / 16)
-
-            # Resize to the exact resolution used during training
+            # Convert to RGB
             image = image.convert("RGB")
-            final_width = 8 * scaled_width
-            final_height = 8 * scaled_height
-            # Use BICUBIC for smoother resizing to reduce artifacts
-            image = image.resize((final_width, final_height), Image.Resampling.BICUBIC)
 
             # Convert to tensor using torchvision transforms for consistency
-            # This matches the normalization used in image_resized_to_grid_as_tensor
             transformation = T.Compose(
                 [
                     T.ToTensor(),  # Converts PIL image to tensor and scales to [0, 1]
@@ -161,6 +141,15 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
             # Extract tensor dimensions
             batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
 
+            # Pad latents to be compatible with patch_size=2
+            # This ensures dimensions are even for the pack() function
+            pad_h = (2 - latent_height % 2) % 2
+            pad_w = (2 - latent_width % 2) % 2
+            if pad_h > 0 or pad_w > 0:
+                kontext_latents_unpacked = F.pad(kontext_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
+                # Update dimensions after padding
+                _, _, latent_height, latent_width = kontext_latents_unpacked.shape
+
             # Pack the latents
             kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
 
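
Note on the padding hunk above: pack() flattens latents into 2x2 patches, so both latent dimensions must be even before packing. The sketch below illustrates that interaction; pack_sketch is only an assumed stand-in for the FLUX-style rearrange that invokeai.backend.flux.sampling_utils.pack performs, and the example shapes are arbitrary, so this is not the InvokeAI implementation.

import torch
import torch.nn.functional as F
from einops import rearrange


def pack_sketch(x: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in for sampling_utils.pack: split H and W into 2x2 patches
    # and fold them into the channel axis. Fails unless H % 2 == 0 and W % 2 == 0.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


latents = torch.randn(1, 16, 107, 96)  # odd height would break the 2x2 patching

# Same scheme as the diff: F.pad orders the last two dims as
# (w_left, w_right, h_top, h_bottom), so the width padding comes first.
pad_h = (2 - latents.shape[-2] % 2) % 2
pad_w = (2 - latents.shape[-1] % 2) % 2
if pad_h > 0 or pad_w > 0:
    latents = F.pad(latents, (0, pad_w, 0, pad_h), mode="circular")

packed = pack_sketch(latents)
print(packed.shape)  # torch.Size([1, 2592, 64]) -> (108 // 2) * (96 // 2) patch tokens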