Skip to content

Commit 9b2970a

Browse files
committed
add flux fill pipeline
1 parent e564abe commit 9b2970a

File tree

12 files changed

+993
-11
lines changed

12 files changed

+993
-11
lines changed

scripts/convert_flux_to_diffusers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,17 @@ def main(args):
279279
num_single_layers = 38
280280
inner_dim = 3072
281281
mlp_ratio = 4.0
282+
283+
# dev has 64, dev-fill has 384
284+
in_channels = original_ckpt["img_in.weight"].shape[1]
285+
out_channels = 64
286+
282287
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
283288
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
284289
)
285-
transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
290+
transformer = FluxTransformer2DModel(
291+
guidance_embeds=has_guidance, in_channels=in_channels, out_channels=out_channels
292+
)
286293
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
287294

288295
print(

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@
272272
"FluxControlNetImg2ImgPipeline",
273273
"FluxControlNetInpaintPipeline",
274274
"FluxControlNetPipeline",
275+
"FluxFillPipeline",
275276
"FluxImg2ImgPipeline",
276277
"FluxInpaintPipeline",
277278
"FluxPipeline",
@@ -737,6 +738,7 @@
737738
FluxControlNetImg2ImgPipeline,
738739
FluxControlNetInpaintPipeline,
739740
FluxControlNetPipeline,
741+
FluxFillPipeline,
740742
FluxImg2ImgPipeline,
741743
FluxInpaintPipeline,
742744
FluxPipeline,

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def __init__(
238238
self,
239239
patch_size: int = 1,
240240
in_channels: int = 64,
241+
out_channels: int = None,
241242
num_layers: int = 19,
242243
num_single_layers: int = 38,
243244
attention_head_dim: int = 128,
@@ -248,7 +249,10 @@ def __init__(
248249
axes_dims_rope: Tuple[int] = (16, 56, 56),
249250
):
250251
super().__init__()
251-
self.out_channels = in_channels
252+
if out_channels is None:
253+
self.out_channels = in_channels
254+
else:
255+
self.out_channels = out_channels
252256
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
253257

254258
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
"FluxImg2ImgPipeline",
134134
"FluxInpaintPipeline",
135135
"FluxPipeline",
136+
"FluxFillPipeline",
136137
]
137138
_import_structure["audioldm"] = ["AudioLDMPipeline"]
138139
_import_structure["audioldm2"] = [
@@ -524,6 +525,7 @@
524525
FluxControlNetImg2ImgPipeline,
525526
FluxControlNetInpaintPipeline,
526527
FluxControlNetPipeline,
528+
FluxFillPipeline,
527529
FluxImg2ImgPipeline,
528530
FluxInpaintPipeline,
529531
FluxPipeline,

src/diffusers/pipelines/flux/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
2727
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
2828
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
29+
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
2930
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
3031
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
3132
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -39,6 +40,7 @@
3940
from .pipeline_flux_controlnet import FluxControlNetPipeline
4041
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
4142
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
43+
from .pipeline_flux_fill import FluxFillPipeline
4244
from .pipeline_flux_img2img import FluxImg2ImgPipeline
4345
from .pipeline_flux_inpaint import FluxInpaintPipeline
4446
else:

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def prepare_latents(
513513
shape = (batch_size, num_channels_latents, height, width)
514514

515515
if latents is not None:
516-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
516+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
517517
return latents.to(device=device, dtype=dtype), latent_image_ids
518518

519519
if isinstance(generator, list) and len(generator) != batch_size:

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def calculate_shift(
9797
return mu
9898

9999

100+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101+
def retrieve_latents(
102+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
103+
):
104+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
105+
return encoder_output.latent_dist.sample(generator)
106+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
107+
return encoder_output.latent_dist.mode()
108+
elif hasattr(encoder_output, "latents"):
109+
return encoder_output.latents
110+
else:
111+
raise AttributeError("Could not access latents of provided encoder_output")
112+
113+
100114
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101115
def retrieve_timesteps(
102116
scheduler,
@@ -512,7 +526,7 @@ def prepare_latents(
512526
shape = (batch_size, num_channels_latents, height, width)
513527

514528
if latents is not None:
515-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
529+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
516530
return latents.to(device=device, dtype=dtype), latent_image_ids
517531

518532
if isinstance(generator, list) and len(generator) != batch_size:
@@ -772,7 +786,7 @@ def __call__(
772786
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
773787
if self.controlnet.input_hint_block is None:
774788
# vae encode
775-
control_image = self.vae.encode(control_image).latent_dist.sample()
789+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
776790
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
777791

778792
# pack
@@ -810,7 +824,7 @@ def __call__(
810824

811825
if self.controlnet.nets[0].input_hint_block is None:
812826
# vae encode
813-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
827+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
814828
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
815829

816830
# pack

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ def __call__(
801801
)
802802
height, width = control_image.shape[-2:]
803803

804-
control_image = self.vae.encode(control_image).latent_dist.sample()
804+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
805805
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
806806

807807
height_control_image, width_control_image = control_image.shape[2:]
@@ -832,7 +832,7 @@ def __call__(
832832
)
833833
height, width = control_image_.shape[-2:]
834834

835-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
835+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
836836
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
837837

838838
height_control_image, width_control_image = control_image_.shape[2:]

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ def __call__(
942942
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
943943
if self.controlnet.input_hint_block is None:
944944
# vae encode
945-
control_image = self.vae.encode(control_image).latent_dist.sample()
945+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
946946
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
947947

948948
# pack
@@ -979,7 +979,7 @@ def __call__(
979979

980980
if self.controlnet.nets[0].input_hint_block is None:
981981
# vae encode
982-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
982+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
983983
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
984984

985985
# pack

0 commit comments

Comments
 (0)