@@ -291,7 +291,6 @@ def _run_diffusion(
291
291
# Pack all latent tensors.
292
292
init_latents = pack (init_latents ) if init_latents is not None else None
293
293
inpaint_mask = pack (inpaint_mask ) if inpaint_mask is not None else None
294
- img_cond = pack (img_cond ) if img_cond is not None else None
295
294
noise = pack (noise )
296
295
x = pack (x )
297
296
@@ -663,13 +662,12 @@ def _prep_controlnet_extensions(
663
662
664
663
return controlnet_extensions
665
664
666
- def _prep_structural_control_img_cond (self , context : InvocationContext ) -> torch .Tensor | None :
667
- if self .control_lora is None :
668
- return None
669
-
665
+ def _prep_structural_control_img_cond (self , context : InvocationContext ) -> torch .Tensor :
670
666
if not self .controlnet_vae :
671
667
raise ValueError ("controlnet_vae must be set when using a FLUX Control LoRA." )
672
668
669
+ assert self .control_lora is not None
670
+
673
671
# Load the conditioning image and resize it to the target image size.
674
672
cond_img = context .images .get_pil (self .control_lora .img .image_name )
675
673
cond_img = cond_img .convert ("RGB" )
@@ -684,23 +682,24 @@ def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch
684
682
img_cond = einops .rearrange (img_cond , "h w c -> 1 c h w" )
685
683
686
684
vae_info = context .models .load (self .controlnet_vae .vae )
687
- return FluxVaeEncodeInvocation .vae_encode (vae_info = vae_info , image_tensor = img_cond )
685
+ img_cond = FluxVaeEncodeInvocation .vae_encode (vae_info = vae_info , image_tensor = img_cond )
686
+
687
+ return pack (img_cond )
688
688
689
689
def _prep_flux_fill_img_cond (
690
690
self , context : InvocationContext , device : torch .device , dtype : torch .dtype
691
- ) -> torch .Tensor | None :
691
+ ) -> torch .Tensor :
692
692
"""Prepare the FLUX Fill conditioning.
693
693
694
694
This logic is based on:
695
695
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
696
696
"""
697
- if self .fill_conditioning is None :
698
- return None
699
-
700
697
# TODO(ryand): We should probably rename controlnet_vae. It's used for more than just ControlNets.
701
698
if not self .controlnet_vae :
702
699
raise ValueError ("controlnet_vae must be set when using a FLUX Fill conditioning." )
703
700
701
+ assert self .fill_conditioning is not None
702
+
704
703
# Load the conditioning image and resize it to the target image size.
705
704
cond_img = context .images .get_pil (self .fill_conditioning .image .image_name , mode = "RGB" )
706
705
cond_img = cond_img .resize ((self .width , self .height ), Image .Resampling .BICUBIC )
@@ -711,9 +710,13 @@ def _prep_flux_fill_img_cond(
711
710
712
711
# Load the mask and resize it to the target image size.
713
712
mask = context .tensors .load (self .fill_conditioning .mask .tensor_name )
713
+ # We expect mask to be a bool tensor with shape [1, H, W].
714
714
assert mask .dtype == torch .bool
715
+ assert mask .dim () == 3
716
+ assert mask .shape [0 ] == 1
717
+ mask = tv_resize (mask , size = [self .height , self .width ], interpolation = tv_transforms .InterpolationMode .NEAREST )
715
718
mask = mask .to (device = device , dtype = dtype )
716
- mask = einops .rearrange (mask , "h w -> 1 1 h w" )
719
+ mask = einops .rearrange (mask , "1 h w -> 1 1 h w" )
717
720
718
721
# Prepare image conditioning.
719
722
cond_img = cond_img * (1 - mask )
@@ -724,12 +727,7 @@ def _prep_flux_fill_img_cond(
724
727
# Prepare mask conditioning.
725
728
mask = mask [:, 0 , :, :]
726
729
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
727
- mask = einops .rearrange (
728
- mask ,
729
- "b (h ph) (w pw) -> b (ph pw) h w" ,
730
- ph = 8 ,
731
- pw = 8 ,
732
- )
730
+ mask = einops .rearrange (mask , "b (h ph) (w pw) -> b (ph pw) h w" , ph = 8 , pw = 8 )
733
731
mask = pack (mask )
734
732
735
733
# Merge image and mask conditioning.
0 commit comments