Skip to content

Commit 9c9e64e

Browse files
committed
up
1 parent 45286fd commit 9c9e64e

File tree

7 files changed

+61
-20
lines changed

7 files changed

+61
-20
lines changed

scripts/convert_flux_to_diffusers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,17 @@ def main(args):
279279
num_single_layers = 38
280280
inner_dim = 3072
281281
mlp_ratio = 4.0
282-
282+
283283
# dev has 64, dev-fill has 384
284284
in_channels = original_ckpt["img_in.weight"].shape[1]
285285
out_channels = 64
286286

287287
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
288288
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
289289
)
290-
transformer = FluxTransformer2DModel(guidance_embeds=has_guidance, in_channels=in_channels, out_channels=out_channels)
290+
transformer = FluxTransformer2DModel(
291+
guidance_embeds=has_guidance, in_channels=in_channels, out_channels=out_channels
292+
)
291293
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
292294

293295
print(

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,10 @@
272272
"FluxControlNetImg2ImgPipeline",
273273
"FluxControlNetInpaintPipeline",
274274
"FluxControlNetPipeline",
275+
"FluxFillPipeline",
275276
"FluxImg2ImgPipeline",
276277
"FluxInpaintPipeline",
277278
"FluxPipeline",
278-
"FluxFillPipeline",
279279
"HunyuanDiTControlNetPipeline",
280280
"HunyuanDiTPAGPipeline",
281281
"HunyuanDiTPipeline",
@@ -738,10 +738,10 @@
738738
FluxControlNetImg2ImgPipeline,
739739
FluxControlNetInpaintPipeline,
740740
FluxControlNetPipeline,
741+
FluxFillPipeline,
741742
FluxImg2ImgPipeline,
742743
FluxInpaintPipeline,
743744
FluxPipeline,
744-
FluxFillPipeline,
745745
HunyuanDiTControlNetPipeline,
746746
HunyuanDiTPAGPipeline,
747747
HunyuanDiTPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,10 +525,10 @@
525525
FluxControlNetImg2ImgPipeline,
526526
FluxControlNetInpaintPipeline,
527527
FluxControlNetPipeline,
528+
FluxFillPipeline,
528529
FluxImg2ImgPipeline,
529530
FluxInpaintPipeline,
530531
FluxPipeline,
531-
FluxFillPipeline,
532532
)
533533
from .hunyuandit import HunyuanDiTPipeline
534534
from .i2vgen_xl import I2VGenXLPipeline

src/diffusers/pipelines/flux/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
2727
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
2828
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
29+
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
2930
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
3031
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
31-
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
3232
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3333
try:
3434
if not (is_transformers_available() and is_torch_available()):
@@ -40,9 +40,9 @@
4040
from .pipeline_flux_controlnet import FluxControlNetPipeline
4141
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
4242
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
43+
from .pipeline_flux_fill import FluxFillPipeline
4344
from .pipeline_flux_img2img import FluxImg2ImgPipeline
4445
from .pipeline_flux_inpaint import FluxInpaintPipeline
45-
from .pipeline_flux_fill import FluxFillPipeline
4646
else:
4747
import sys
4848

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def prepare_latents(
512512
shape = (batch_size, num_channels_latents, height, width)
513513

514514
if latents is not None:
515-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
515+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
516516
return latents.to(device=device, dtype=dtype), latent_image_ids
517517

518518
if isinstance(generator, list) and len(generator) != batch_size:

src/diffusers/pipelines/flux/pipeline_flux_fill.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"""
6565

6666

67+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
6768
def calculate_shift(
6869
image_seq_len,
6970
base_seq_len: int = 256,
@@ -136,6 +137,7 @@ def retrieve_timesteps(
136137
timesteps = scheduler.timesteps
137138
return timesteps, num_inference_steps
138139

140+
139141
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
140142
def retrieve_latents(
141143
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -226,6 +228,7 @@ def __init__(
226228
)
227229
self.default_sample_size = 128
228230

231+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
229232
def _get_t5_prompt_embeds(
230233
self,
231234
prompt: Union[str, List[str]] = None,
@@ -275,6 +278,7 @@ def _get_t5_prompt_embeds(
275278

276279
return prompt_embeds
277280

281+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
278282
def _get_clip_prompt_embeds(
279283
self,
280284
prompt: Union[str, List[str]],
@@ -318,7 +322,7 @@ def _get_clip_prompt_embeds(
318322
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
319323

320324
return prompt_embeds
321-
325+
322326
def prepare_mask_latents(
323327
self,
324328
mask,
@@ -364,7 +368,7 @@ def prepare_mask_latents(
364368
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
365369

366370
# prepare mask for latents
367-
mask = mask[:,0,:,:]
371+
mask = mask[:, 0, :, :]
368372
mask = mask.view(batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor)
369373
mask = mask.permute(0, 2, 4, 1, 3)
370374
mask = mask.reshape(batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width)
@@ -390,7 +394,7 @@ def prepare_mask_latents(
390394

391395
return mask, masked_image_latents
392396

393-
397+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
394398
def encode_prompt(
395399
self,
396400
prompt: Union[str, List[str]],
@@ -470,6 +474,7 @@ def encode_prompt(
470474

471475
return prompt_embeds, pooled_prompt_embeds, text_ids
472476

477+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
473478
def check_inputs(
474479
self,
475480
prompt,
@@ -521,6 +526,7 @@ def check_inputs(
521526
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
522527

523528
@staticmethod
529+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
524530
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
525531
latent_image_ids = torch.zeros(height, width, 3)
526532
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
@@ -535,6 +541,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
535541
return latent_image_ids.to(device=device, dtype=dtype)
536542

537543
@staticmethod
544+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
538545
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
539546
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
540547
latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -543,6 +550,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
543550
return latents
544551

545552
@staticmethod
553+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
546554
def _unpack_latents(latents, height, width, vae_scale_factor):
547555
batch_size, num_patches, channels = latents.shape
548556

@@ -587,6 +595,7 @@ def disable_vae_tiling(self):
587595
"""
588596
self.vae.disable_tiling()
589597

598+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
590599
def prepare_latents(
591600
self,
592601
batch_size,
@@ -644,6 +653,9 @@ def __call__(
644653
self,
645654
prompt: Union[str, List[str]] = None,
646655
prompt_2: Optional[Union[str, List[str]]] = None,
656+
image: Optional[torch.FloatTensor] = None,
657+
mask_image: Optional[torch.FloatTensor] = None,
658+
masked_image_latents: Optional[torch.FloatTensor] = None,
647659
height: Optional[int] = None,
648660
width: Optional[int] = None,
649661
num_inference_steps: int = 28,
@@ -660,9 +672,6 @@ def __call__(
660672
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
661673
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
662674
max_sequence_length: int = 512,
663-
img_cond: Optional[torch.FloatTensor] = None,
664-
image: Optional[torch.FloatTensor] = None,
665-
mask_image: Optional[torch.FloatTensor] = None,
666675
):
667676
r"""
668677
Function invoked when calling the pipeline for generation.
@@ -674,6 +683,22 @@ def __call__(
674683
prompt_2 (`str` or `List[str]`, *optional*):
675684
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
676685
will be used instead
686+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
687+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
688+
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
689+
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
690+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
691+
latents as `image`, but if passing latents directly it is not encoded again.
692+
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
693+
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
694+
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
695+
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
696+
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
697+
H, W)`, `(1, H, W)`, `(H, W)`. And for a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
698+
1)`, or `(H, W)`.
699+
masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
700+
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
701+
latents tensor will be generated from `mask_image`.
677702
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
678703
The height in pixels of the generated image. This is set to 1024 by default for the best results.
679704
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -794,18 +819,17 @@ def __call__(
794819
latents,
795820
)
796821

797-
if img_cond is not None:
798-
img_cond = img_cond.to(latents.device)
822+
if masked_image_latents is not None:
823+
masked_image_latents = masked_image_latents.to(latents.device)
799824
else:
800-
801825
if image is not None and mask_image is not None:
802826
image = self.image_processor.preprocess(image)
803827
mask_image = self.mask_processor.preprocess(mask_image)
804828
masked_image = image * (1 - mask_image)
805829
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
806830

807831
height, width = image.shape[-2:]
808-
832+
809833
mask, masked_image_latents = self.prepare_mask_latents(
810834
mask_image,
811835
masked_image,
@@ -818,7 +842,7 @@ def __call__(
818842
device,
819843
generator,
820844
)
821-
img_cond = torch.cat((masked_image_latents, mask), dim=-1)
845+
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
822846

823847
# 5. Prepare timesteps
824848
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
@@ -858,7 +882,7 @@ def __call__(
858882
timestep = t.expand(latents.shape[0]).to(latents.dtype)
859883

860884
noise_pred = self.transformer(
861-
hidden_states=torch.cat((latents, img_cond), dim=2) if img_cond is not None else latents,
885+
hidden_states=torch.cat((latents, masked_image_latents), dim=2),
862886
timestep=timestep / 1000,
863887
guidance=guidance,
864888
pooled_projections=pooled_prompt_embeds,

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,21 @@ def from_pretrained(cls, *args, **kwargs):
422422
requires_backends(cls, ["torch", "transformers"])
423423

424424

425+
class FluxFillPipeline(metaclass=DummyObject):
426+
_backends = ["torch", "transformers"]
427+
428+
def __init__(self, *args, **kwargs):
429+
requires_backends(self, ["torch", "transformers"])
430+
431+
@classmethod
432+
def from_config(cls, *args, **kwargs):
433+
requires_backends(cls, ["torch", "transformers"])
434+
435+
@classmethod
436+
def from_pretrained(cls, *args, **kwargs):
437+
requires_backends(cls, ["torch", "transformers"])
438+
439+
425440
class FluxImg2ImgPipeline(metaclass=DummyObject):
426441
_backends = ["torch", "transformers"]
427442

0 commit comments

Comments
 (0)