
Commit f13a07b

RyanJDick authored and psychedelicious committed
WIP on updating FluxDenoise to support FLUX Fill.
1 parent a913f01 commit f13a07b

File tree

1 file changed: +66 −2


invokeai/app/invocations/flux_denoise.py

Lines changed: 66 additions & 2 deletions
@@ -267,8 +267,22 @@ def _run_diffusion(
         if is_schnell and self.control_lora:
             raise ValueError("Control LoRAs cannot be used with FLUX Schnell")

-        # Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
-        img_cond = self._prep_structural_control_img_cond(context)
+        # TODO(ryand): It's a bit confusing that we support inpainting via both FLUX Fill and masked image-to-image.
+        # Think about ways to tidy this interface, or at least add clear error messages when incompatible inputs are
+        # provided.
+
+        # Prepare the extra image conditioning tensor if either of the following is provided:
+        # - FLUX structural control image
+        # - FLUX Fill conditioning
+        img_cond: torch.Tensor | None = None
+        if self.control_lora is not None and self.fill_conditioning is not None:
+            raise ValueError("Control LoRA and Fill conditioning cannot be used together.")
+        elif self.control_lora is not None:
+            img_cond = self._prep_structural_control_img_cond(context)
+        elif self.fill_conditioning is not None:
+            img_cond = self._prep_flux_fill_img_cond(
+                context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
+            )

         inpaint_mask = self._prep_inpaint_mask(context, x)
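A note on how the img_cond produced by this branch is typically consumed: in the BFL reference sampling code (linked from the new method in the next hunk), the packed conditioning tokens are concatenated with the packed noised latents along the feature dimension at every denoise step. A minimal sketch, assuming that convention; the helper name apply_img_cond is hypothetical, and InvokeAI's own denoise loop may wire this differently:

import torch


def apply_img_cond(img: torch.Tensor, img_cond: torch.Tensor | None) -> torch.Tensor:
    # img:      (batch, tokens, 64) packed noised latents.
    # img_cond: (batch, tokens, C) packed conditioning, where C is 64 for a structural
    #           control image and 320 for FLUX Fill (64 image + 256 mask features), or None.
    if img_cond is None:
        return img
    return torch.cat((img, img_cond), dim=-1)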

@@ -672,6 +686,56 @@ def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch
         vae_info = context.models.load(self.controlnet_vae.vae)
         return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)

+    def _prep_flux_fill_img_cond(
+        self, context: InvocationContext, device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor | None:
+        """Prepare the FLUX Fill conditioning.
+
+        This logic is based on:
+        https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
+        """
+        if self.fill_conditioning is None:
+            return None
+
+        # TODO(ryand): We should probably rename controlnet_vae. It's used for more than just ControlNets.
+        if not self.controlnet_vae:
+            raise ValueError("controlnet_vae must be set when using FLUX Fill conditioning.")
+
+        # Load the conditioning image and resize it to the target image size.
+        cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
+        cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
+        cond_img = np.array(cond_img)
+        cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
+        cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
+        cond_img = cond_img.to(device=device, dtype=dtype)
+
+        # Load the mask.
+        mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
+        assert mask.dtype == torch.bool
+        mask = mask.to(device=device, dtype=dtype)
+        mask = einops.rearrange(mask, "h w -> 1 1 h w")
+
+        # Prepare image conditioning.
+        cond_img = cond_img * (1 - mask)
+        vae_info = context.models.load(self.controlnet_vae.vae)
+        cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
+        cond_img = pack(cond_img)
+
+        # Prepare mask conditioning.
+        mask = mask[:, 0, :, :]
+        # Rearrange the mask into a 64-channel representation whose spatial dimensions match the
+        # VAE-encoded latent space (8x spatial downsampling).
+        mask = einops.rearrange(
+            mask,
+            "b (h ph) (w pw) -> b (ph pw) h w",
+            ph=8,
+            pw=8,
+        )
+        mask = pack(mask)
+
+        # Merge image and mask conditioning.
+        img_cond = torch.cat((cond_img, mask), dim=-1)
+        return img_cond
+
     def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
         if self.ip_adapter is None:
            return []
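For reference, a self-contained sketch of the shape arithmetic in _prep_flux_fill_img_cond, assuming pack is the standard FLUX 2x2 patchify (the local pack definition, the 1024x1024 size, and the dummy tensors below are only illustrative, not InvokeAI code):

import einops
import torch


def pack(x: torch.Tensor) -> torch.Tensor:
    # Assumed definition of the FLUX "pack" step: group 2x2 latent patches into tokens.
    return einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


h, w = 1024, 1024  # illustrative target image size

# VAE-encoded conditioning image: 16 latent channels at 1/8 spatial resolution.
cond_latent = torch.randn(1, 16, h // 8, w // 8)
cond_tokens = pack(cond_latent)  # -> (1, (h/16)*(w/16), 64)

# Mask: 1 channel at image resolution -> 64 channels at latent resolution -> packed tokens.
mask = torch.ones(1, h, w)
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)  # (1, 64, h/8, w/8)
mask_tokens = pack(mask)  # -> (1, (h/16)*(w/16), 256)

# Concatenate along the feature dimension: 64 + 256 = 320 conditioning features per token.
img_cond = torch.cat((cond_tokens, mask_tokens), dim=-1)
print(cond_tokens.shape, mask_tokens.shape, img_cond.shape)
# torch.Size([1, 4096, 64]) torch.Size([1, 4096, 256]) torch.Size([1, 4096, 320])

Concatenated with the 64 packed latent features per token, this gives the 384-feature per-token input that the FLUX Fill model expects.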
