Skip to content

Commit 5ea3ec5

Browse files
RyanJDick authored and psychedelicious committed
Get FLUX Fill working. Note: To use FLUX Fill, set guidance to ~30.
1 parent f13a07b commit 5ea3ec5

File tree

3 files changed

+42
-26
lines changed

3 files changed

+42
-26
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def _run_diffusion(
291291
# Pack all latent tensors.
292292
init_latents = pack(init_latents) if init_latents is not None else None
293293
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
294-
img_cond = pack(img_cond) if img_cond is not None else None
295294
noise = pack(noise)
296295
x = pack(x)
297296

@@ -663,13 +662,12 @@ def _prep_controlnet_extensions(
663662

664663
return controlnet_extensions
665664

666-
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
667-
if self.control_lora is None:
668-
return None
669-
665+
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor:
670666
if not self.controlnet_vae:
671667
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")
672668

669+
assert self.control_lora is not None
670+
673671
# Load the conditioning image and resize it to the target image size.
674672
cond_img = context.images.get_pil(self.control_lora.img.image_name)
675673
cond_img = cond_img.convert("RGB")
@@ -684,23 +682,24 @@ def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch
684682
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
685683

686684
vae_info = context.models.load(self.controlnet_vae.vae)
687-
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
685+
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
686+
687+
return pack(img_cond)
688688

689689
def _prep_flux_fill_img_cond(
690690
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
691-
) -> torch.Tensor | None:
691+
) -> torch.Tensor:
692692
"""Prepare the FLUX Fill conditioning.
693693
694694
This logic is based on:
695695
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
696696
"""
697-
if self.fill_conditioning is None:
698-
return None
699-
700697
# TODO(ryand): We should probably rename controlnet_vae. It's used for more than just ControlNets.
701698
if not self.controlnet_vae:
702699
raise ValueError("controlnet_vae must be set when using a FLUX Fill conditioning.")
703700

701+
assert self.fill_conditioning is not None
702+
704703
# Load the conditioning image and resize it to the target image size.
705704
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
706705
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
@@ -711,9 +710,13 @@ def _prep_flux_fill_img_cond(
711710

712711
# Load the mask and resize it to the target image size.
713712
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
713+
# We expect mask to be a bool tensor with shape [1, H, W].
714714
assert mask.dtype == torch.bool
715+
assert mask.dim() == 3
716+
assert mask.shape[0] == 1
717+
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
715718
mask = mask.to(device=device, dtype=dtype)
716-
mask = einops.rearrange(mask, "h w -> 1 1 h w")
719+
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
717720

718721
# Prepare image conditioning.
719722
cond_img = cond_img * (1 - mask)
@@ -724,12 +727,7 @@ def _prep_flux_fill_img_cond(
724727
# Prepare mask conditioning.
725728
mask = mask[:, 0, :, :]
726729
# Rearrange mask to a 64-channel representation (each 8x8 pixel patch is folded into the channel dim) so that its spatial shape matches the VAE-encoded latent space.
727-
mask = einops.rearrange(
728-
mask,
729-
"b (h ph) (w pw) -> b (ph pw) h w",
730-
ph=8,
731-
pw=8,
732-
)
730+
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
733731
mask = pack(mask)
734732

735733
# Merge image and mask conditioning.

invokeai/backend/flux/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class ModelSpec:
2020

2121
max_seq_lengths: Dict[str, Literal[256, 512]] = {
2222
"flux-dev": 512,
23+
"flux-dev-fill": 512,
2324
"flux-schnell": 256,
2425
}
2526

@@ -68,4 +69,19 @@ class ModelSpec:
6869
qkv_bias=True,
6970
guidance_embed=False,
7071
),
72+
"flux-dev-fill": FluxParams(
73+
in_channels=384,
74+
out_channels=64,
75+
vec_in_dim=768,
76+
context_in_dim=4096,
77+
hidden_size=3072,
78+
mlp_ratio=4.0,
79+
num_heads=24,
80+
depth=19,
81+
depth_single_blocks=38,
82+
axes_dim=[16, 56, 56],
83+
theta=10_000,
84+
qkv_bias=True,
85+
guidance_embed=True,
86+
),
7187
}

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -417,20 +417,22 @@ def _get_checkpoint_config_path(
417417
# TODO: Decide between dev/schnell
418418
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
419419
state_dict = checkpoint.get("state_dict") or checkpoint
420+
421+
# HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model
422+
# loading. When FLUX support was first added, it was decided that this was the easiest way to support
423+
# the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in
424+
# the future.
420425
if (
421426
"guidance_in.out_layer.weight" in state_dict
422427
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
423428
):
424-
# For flux, this is a key in invokeai.backend.flux.util.params
425-
# Due to model type and format being the descriminator for model configs this
426-
# is used rather than attempting to support flux with separate model types and format
427-
# If changed in the future, please fix me
428-
config_file = "flux-dev"
429+
if variant_type == ModelVariantType.Normal:
430+
config_file = "flux-dev"
431+
elif variant_type == ModelVariantType.Inpaint:
432+
config_file = "flux-dev-fill"
433+
else:
434+
raise ValueError(f"Unexpected FLUX variant type: {variant_type}")
429435
else:
430-
# For flux, this is a key in invokeai.backend.flux.util.params
431-
# Due to model type and format being the discriminator for model configs this
432-
# is used rather than attempting to support flux with separate model types and format
433-
# If changed in the future, please fix me
434436
config_file = "flux-schnell"
435437
else:
436438
config_file = LEGACY_CONFIGS[base_type][variant_type]

0 commit comments

Comments
 (0)