7 changes: 6 additions & 1 deletion scripts/convert_flux_to_diffusers.py
@@ -37,6 +37,8 @@
 parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
 parser.add_argument("--filename", default="flux.safetensors", type=str)
 parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--in_channels", type=int, default=64)
+parser.add_argument("--out_channels", type=int, default=None)
 parser.add_argument("--vae", action="store_true")
 parser.add_argument("--transformer", action="store_true")
 parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
         num_single_layers = 38
         inner_dim = 3072
         mlp_ratio = 4.0
+
         converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
             original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
         )
-        transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+        transformer = FluxTransformer2DModel(
+            in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
+        )
         transformer.load_state_dict(converted_transformer_state_dict, strict=True)

         print(
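The new `--in_channels`/`--out_channels` flags let the script build transformer variants whose input width differs from the 64 packed latent channels of base Flux. A hypothetical invocation sketch; the checkpoint path, channel counts, and output directory below are illustrative, not taken from this PR:

```python
# Hypothetical invocation of the updated conversion script; paths and
# channel counts are illustrative only.
import subprocess

subprocess.run(
    [
        "python", "scripts/convert_flux_to_diffusers.py",
        "--checkpoint_path", "flux-variant.safetensors",  # illustrative local checkpoint
        "--in_channels", "128",   # e.g. 64 packed latent channels + 64 condition channels
        "--out_channels", "64",   # the prediction keeps the base latent width
        "--transformer",
        "--output_path", "./flux-variant-diffusers",
    ],
    check=True,
)
```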
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -272,6 +272,7 @@
         "FluxControlNetImg2ImgPipeline",
         "FluxControlNetInpaintPipeline",
         "FluxControlNetPipeline",
+        "FluxFillPipeline",
         "FluxImg2ImgPipeline",
         "FluxInpaintPipeline",
         "FluxPipeline",
@@ -737,6 +738,7 @@
         FluxControlNetImg2ImgPipeline,
         FluxControlNetInpaintPipeline,
         FluxControlNetPipeline,
+        FluxFillPipeline,
         FluxImg2ImgPipeline,
         FluxInpaintPipeline,
         FluxPipeline,
7 changes: 5 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -238,6 +238,7 @@ def __init__(
         self,
         patch_size: int = 1,
         in_channels: int = 64,
+        out_channels: Optional[int] = None,
         num_layers: int = 19,
         num_single_layers: int = 38,
         attention_head_dim: int = 128,
@@ -248,7 +249,7 @@
         axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@ def __init__(
         )

         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
@@ -449,13 +450,15 @@ def forward(
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
+
         hidden_states = self.x_embedder(hidden_states)

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
         else:
             guidance = None
+
         temb = (
             self.time_text_embed(timestep, pooled_projections)
             if guidance is None
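Decoupling `out_channels` from `in_channels` is what lets one checkpoint consume wider packed inputs (image latents concatenated with control or mask latents) while still predicting the base 64-channel latent. A minimal construction sketch; the tiny layer counts and the 128/64 split are illustrative, not a real Flux config:

```python
from diffusers import FluxTransformer2DModel

# Tiny illustrative config: real Flux uses num_layers=19, num_single_layers=38,
# attention_head_dim=128, and axes_dims_rope=(16, 56, 56).
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=128,           # e.g. 64 packed image channels + 64 condition channels
    out_channels=64,           # prediction keeps the base latent width
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=(2, 2, 4),  # must sum to attention_head_dim
)

print(model.x_embedder.in_features)   # 128: the input projection is widened
print(model.proj_out.out_features)    # 64: patch_size**2 * out_channels
```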
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -133,6 +133,7 @@
         "FluxImg2ImgPipeline",
         "FluxInpaintPipeline",
         "FluxPipeline",
+        "FluxFillPipeline",
     ]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
     _import_structure["audioldm2"] = [
@@ -524,6 +525,7 @@
         FluxControlNetImg2ImgPipeline,
         FluxControlNetInpaintPipeline,
         FluxControlNetPipeline,
+        FluxFillPipeline,
         FluxImg2ImgPipeline,
         FluxInpaintPipeline,
         FluxPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/flux/__init__.py
@@ -26,6 +26,7 @@
     _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
     _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
     _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+    _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
     _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
     _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -39,6 +40,7 @@
     from .pipeline_flux_controlnet import FluxControlNetPipeline
     from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
     from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+    from .pipeline_flux_fill import FluxFillPipeline
     from .pipeline_flux_img2img import FluxImg2ImgPipeline
     from .pipeline_flux_inpaint import FluxInpaintPipeline
 else:
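Once registered in the three `__init__` files above, the new pipeline resolves through diffusers' lazy-import machinery. A quick smoke test, assuming a diffusers install that includes this PR:

```python
# Quick import smoke test; assumes a diffusers version with this PR merged.
from diffusers import FluxFillPipeline

print(FluxFillPipeline.__module__)  # diffusers.pipelines.flux.pipeline_flux_fill
```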
86 changes: 82 additions & 4 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -19,7 +19,7 @@
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
@@ -529,6 +529,41 @@ def prepare_latents(

         return latents, latent_image_ids

+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
[Review comment, Collaborator]: The easiest way is probably to have 3 separate FluxControlPipeline? but maybe @DN6 can come up with better ideas!
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
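The batching rule above mirrors the SD3 ControlNet helper it was copied from: a single conditioning image is repeated to cover the whole batch, while a per-prompt batch of images is repeated once per generated sample. A standalone sketch of that logic, with illustrative shapes:

```python
import torch

# One conditioning image, batch_size * num_images_per_prompt = 4:
image = torch.randn(1, 3, 1024, 1024)
print(image.repeat_interleave(4, dim=0).shape)   # torch.Size([4, 3, 1024, 1024])

# Per-prompt images (prompt batch of 2), num_images_per_prompt = 2:
images = torch.randn(2, 3, 1024, 1024)
print(images.repeat_interleave(2, dim=0).shape)  # torch.Size([4, 3, 1024, 1024])
```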

     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -556,9 +591,11 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -595,6 +632,14 @@
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The conditioning input that provides guidance to the transformer for generation. If the type is
+                specified as `torch.Tensor`, it is passed to the model as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple conditioning inputs are specified in
+                `init`, images must be passed as a list such that each element of the list can be correctly batched
+                for input to a single conditioning branch.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -667,6 +712,7 @@ def __call__(

         device = self._execution_device

+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -686,7 +732,35 @@
         )

         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -732,11 +806,16 @@ def __call__(
                 if self.interrupt:
                     continue

+                if control_latents is not None:
+                    latent_model_input = torch.cat([latents, control_latents], dim=2)
+                else:
+                    latent_model_input = latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -774,7 +853,6 @@ def __call__(

         if output_type == "latent":
             image = latents
-
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
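Taken together, the FluxPipeline changes let a single checkpoint consume a conditioning image by concatenating its packed latents with the noise latents at every denoising step. A hypothetical end-to-end call; the repo id, input file, and prompt are illustrative, and a transformer trained with doubled input channels is assumed:

```python
# Hypothetical usage of the new control_image argument; the repo id and input
# file are illustrative, and a transformer with doubled in_channels is assumed.
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained(
    "path/to/flux-control-checkpoint",  # illustrative; not a real repo id
    torch_dtype=torch.bfloat16,
).to("cuda")

control_image = load_image("conditioning.png")  # illustrative local file

image = pipe(
    prompt="a misty forest at dawn",
    control_image=control_image,  # VAE-encoded and packed like the latents
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
image.save("output.png")
```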
1 change: 1 addition & 0 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -750,6 +750,7 @@ def __call__(
         device = self._execution_device
         dtype = self.transformer.dtype

+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )