Merged. Changes from 6 commits.
393 changes: 393 additions & 0 deletions scripts/convert_flux_control_lora_to_diffusers.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion scripts/convert_flux_to_diffusers.py
@@ -37,6 +37,8 @@
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--in_channels", type=int, default=64)
parser.add_argument("--out_channels", type=int, default=None)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0

converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
)
transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
transformer = FluxTransformer2DModel(
in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)

print(
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -272,6 +272,7 @@
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
@@ -737,6 +738,7 @@
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
54 changes: 49 additions & 5 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1787,14 +1787,41 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
has_lora_keys = any("lora" in key for key in state_dict.keys())

# Flux Control LoRAs also have norm keys
supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys)
@a-r-r-o-w (Contributor, Author) commented on Nov 21, 2024:

For supporting the additional norm layers. Also FYI, the norm layers from the LoRA are numerically identical to those in Flux1-Canny-Dev and Flux1-Depth-Dev, but different from Flux1-Dev (the model the LoRA is intended for), so we cannot do without this change.
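
A hypothetical spot-check of that observation (file names and the key are assumptions, and both state dicts are assumed to already use diffusers-style keys; this is not part of the diff):

import torch
from safetensors.torch import load_file

lora_sd = load_file("flux1-depth-dev-lora.safetensors")
base_sd = load_file("flux1-depth-dev-transformer.safetensors")
key = "transformer_blocks.0.attn.norm_q.weight"
# Expected to print True against Canny/Depth-Dev weights and False against Flux1-Dev, per the comment above.
print(torch.allclose(lora_sd[key], base_sd[key]))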

A maintainer (Member) replied:

Exactly. Thanks for confirming!


if not (has_lora_keys or has_norm_keys):
A maintainer (Member) commented:

I think a better and more robust check could be to see if all the keys have either "lora" or the supported norm key prefixes. WDYT?
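
A minimal sketch of the stricter check suggested above, reusing supported_norm_keys from the surrounding code (illustrative only, not part of the diff):

all_keys_supported = all(
    "lora" in key or any(norm_key in key for norm_key in supported_norm_keys)
    for key in state_dict
)
if not all_keys_supported:
    raise ValueError("Invalid LoRA checkpoint.")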

raise ValueError("Invalid LoRA checkpoint.")

transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
if len(transformer_state_dict) > 0:
def prune_state_dict_(state_dict):
pruned_keys = []
for key in list(state_dict.keys()):
is_lora_key_present = "lora" in key
is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys)
if not is_lora_key_present and not is_norm_key_present:
state_dict.pop(key)
pruned_keys.append(key)
return pruned_keys

pruned_keys = prune_state_dict_(state_dict)
if len(pruned_keys) > 0:
logger.warning(
f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}"
)
A maintainer (Member) commented:

Should be tested as well (can be handled either close to merge or in a separate PR).


transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k}
transformer_norm_state_dict = {
k: v
for k, v in state_dict.items()
if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
}
@a-r-r-o-w (Contributor, Author) commented on Nov 21, 2024:

I hope this is self-explanatory. load_lora_adapter is incompatible with anything that lacks lora keys, so we separate the state dict into the norm and lora dicts.

We also remove any other incompatible layers (this was just for sanity-checking that I was doing things correctly) and raise a warning if there are any additional keys (there are none at the moment, but good to have IMO).

A maintainer (Member) replied:

Indeed!

Should we instead pop from the state_dict and raise a warning and error if there's anything remaining inside it?


if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer(
state_dict,
transformer_lora_state_dict,
network_alphas=network_alphas,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
Expand All @@ -1804,6 +1831,14 @@ def load_lora_weights(
low_cpu_mem_usage=low_cpu_mem_usage,
)

if len(transformer_norm_state_dict) > 0:
self.load_norm_into_transformer(
transformer_norm_state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
)

text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
@@ -1860,6 +1895,15 @@ def load_lora_into_transformer(
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
def load_norm_into_transformer(
cls,
state_dict,
A maintainer (Member) suggested a change: rename the state_dict parameter of load_norm_into_transformer to norm_state_dict.

transformer: torch.nn.Module,
):
print(state_dict.keys())
transformer.load_state_dict(state_dict, strict=True)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
4 changes: 3 additions & 1 deletion src/diffusers/loaders/peft.py
@@ -216,7 +216,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans

rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension.
if "lora_B" in key and val.ndim > 1:
The PR author commented:

Normally, we don't train the lora_B bias, but the Control LoRAs have trained biases. In this case, the assumption that we can index val.shape[1] is incorrect, since a bias is an ndim=1 tensor.
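
An illustrative sketch of the shape difference (sizes are made up, not taken from the PR):

import torch

lora_B_weight = torch.zeros(3072, 4)  # 2-D: shape[1] == 4 gives the rank
lora_B_bias = torch.zeros(3072)       # 1-D: shape[1] would raise an IndexError, hence the ndim > 1 guard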

rank[key] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
7 changes: 5 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -238,6 +238,7 @@ def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
@@ -248,7 +249,7 @@
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@
)

self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

self.transformer_blocks = nn.ModuleList(
[
@@ -449,13 +450,15 @@ def forward(
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.x_embedder(hidden_states)

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None

temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
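
A minimal sketch of what the new out_channels argument permits (layer counts are shrunk so this is cheap to instantiate, and the channel values are assumptions rather than something taken from this PR):

from diffusers import FluxTransformer2DModel

# in_channels=128 mimics packed image latents concatenated with packed control
# latents; out_channels=64 keeps the prediction at the image-latent width.
transformer = FluxTransformer2DModel(
    in_channels=128,
    out_channels=64,
    num_layers=1,
    num_single_layers=1,
)
print(transformer.config.in_channels, transformer.out_channels)  # 128 64
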
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -133,6 +133,7 @@
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
"FluxFillPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -524,6 +525,7 @@
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/flux/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -39,6 +40,7 @@
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
else:
86 changes: 82 additions & 4 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -19,7 +19,7 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
@@ -529,6 +529,41 @@ def prepare_latents(

return latents, latent_image_ids

# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
A collaborator commented:

The easiest way is probably to have 3 separate FluxControlPipeline classes? But maybe @DN6 can come up with better ideas!

def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)

image_batch_size = image.shape[0]

if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt

image = image.repeat_interleave(repeat_by, dim=0)

image = image.to(device=device, dtype=dtype)

if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)

return image

@property
def guidance_scale(self):
return self._guidance_scale
@@ -556,9 +591,11 @@ def __call__(
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
control_image: PipelineImageInput = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
control_latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
@@ -595,6 +632,14 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -667,6 +712,7 @@ def __call__(

device = self._execution_device

# 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
@@ -686,7 +732,35 @@ def __call__(
)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
num_channels_latents = (
self.transformer.config.in_channels // 4
if control_image is None
else self.transformer.config.in_channels // 8
)

if control_image is not None and control_latents is None:
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)

control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

height_control_image, width_control_image = control_latents.shape[2:]
control_latents = self._pack_latents(
control_latents,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -732,11 +806,16 @@ def __call__(
if self.interrupt:
continue

if control_latents is not None:
latent_model_input = torch.cat([latents, control_latents], dim=2)
else:
latent_model_input = latents

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

noise_pred = self.transformer(
hidden_states=latents,
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down Expand Up @@ -774,7 +853,6 @@ def __call__(

if output_type == "latent":
image = latents

else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
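
A rough usage sketch of the new control_image argument (the model and LoRA ids, and whether the loaded transformer already has the wider x_embedder needed for the concatenated latents, are assumptions and not confirmed by this diff). Because the packed control latents are concatenated with the packed noise latents along the channel dimension, the transformer sees twice as many input channels, which is why num_channels_latents becomes in_channels // 8 when a control image is passed.

import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")  # Control LoRA with lora + norm keys

control_image = load_image("canny_edges.png")  # hypothetical local edge map
image = pipe(
    prompt="a robot made of brushed aluminium",
    control_image=control_image,
    num_inference_steps=28,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
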
1 change: 1 addition & 0 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -750,6 +750,7 @@ def __call__(
device = self._execution_device
dtype = self.transformer.dtype

# 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)