@@ -267,8 +267,22 @@ def _run_diffusion(
267
267
if is_schnell and self .control_lora :
268
268
raise ValueError ("Control LoRAs cannot be used with FLUX Schnell" )
269
269
270
- # Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
271
- img_cond = self ._prep_structural_control_img_cond (context )
270
+ # TODO(ryand): It's a bit confusing that we support inpainting via both FLUX Fill and masked image-to-image.
271
+ # Think about ways to tidy this interface, or at least add clear error messages when incompatible inputs are
272
+ # provided.
273
+
274
+ # Prepare the extra image conditioning tensor if either of the following is provided:
275
+ # - FLUX structural control image
276
+ # - FLUX Fill conditioning
277
+ img_cond : torch .Tensor | None = None
278
+ if self .control_lora is not None and self .fill_conditioning is not None :
279
+ raise ValueError ("Control LoRA and Fill conditioning cannot be used together." )
280
+ elif self .control_lora is not None :
281
+ img_cond = self ._prep_structural_control_img_cond (context )
282
+ elif self .fill_conditioning is not None :
283
+ img_cond = self ._prep_flux_fill_img_cond (
284
+ context , device = TorchDevice .choose_torch_device (), dtype = inference_dtype
285
+ )
272
286
273
287
inpaint_mask = self ._prep_inpaint_mask (context , x )
274
288
@@ -672,6 +686,56 @@ def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch
672
686
vae_info = context .models .load (self .controlnet_vae .vae )
673
687
return FluxVaeEncodeInvocation .vae_encode (vae_info = vae_info , image_tensor = img_cond )
674
688
689
+ def _prep_flux_fill_img_cond (
690
+ self , context : InvocationContext , device : torch .device , dtype : torch .dtype
691
+ ) -> torch .Tensor | None :
692
+ """Prepare the FLUX Fill conditioning.
693
+
694
+ This logic is based on:
695
+ https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
696
+ """
697
+ if self .fill_conditioning is None :
698
+ return None
699
+
700
+ # TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
701
+ if not self .controlnet_vae :
702
+ raise ValueError ("controlnet_vae must be set when using a FLUX Fill conditioning." )
703
+
704
+ # Load the conditioning image and resize it to the target image size.
705
+ cond_img = context .images .get_pil (self .fill_conditioning .image .image_name , mode = "RGB" )
706
+ cond_img = cond_img .resize ((self .width , self .height ), Image .Resampling .BICUBIC )
707
+ cond_img = np .array (cond_img )
708
+ cond_img = torch .from_numpy (cond_img ).float () / 127.5 - 1.0
709
+ cond_img = einops .rearrange (cond_img , "h w c -> 1 c h w" )
710
+ cond_img = cond_img .to (device = device , dtype = dtype )
711
+
712
+ # Load the mask and resize it to the target image size.
713
+ mask = context .tensors .load (self .fill_conditioning .mask .tensor_name )
714
+ assert mask .dtype == torch .bool
715
+ mask = mask .to (device = device , dtype = dtype )
716
+ mask = einops .rearrange (mask , "h w -> 1 1 h w" )
717
+
718
+ # Prepare image conditioning.
719
+ cond_img = cond_img * (1 - mask )
720
+ vae_info = context .models .load (self .controlnet_vae .vae )
721
+ cond_img = FluxVaeEncodeInvocation .vae_encode (vae_info = vae_info , image_tensor = cond_img )
722
+ cond_img = pack (cond_img )
723
+
724
+ # Prepare mask conditioning.
725
+ mask = mask [:, 0 , :, :]
726
+ # Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
727
+ mask = einops .rearrange (
728
+ mask ,
729
+ "b (h ph) (w pw) -> b (ph pw) h w" ,
730
+ ph = 8 ,
731
+ pw = 8 ,
732
+ )
733
+ mask = pack (mask )
734
+
735
+ # Merge image and mask conditioning.
736
+ img_cond = torch .cat ((cond_img , mask ), dim = - 1 )
737
+ return img_cond
738
+
675
739
def _normalize_ip_adapter_fields (self ) -> list [IPAdapterField ]:
676
740
if self .ip_adapter is None :
677
741
return []
0 commit comments