Flux Fill, Canny, Depth, Redux #9985
Changes to `load_lora_weights` in the Flux LoRA loader. Instead of rejecting any checkpoint whose keys are not all LoRA keys, the loader now accepts norm keys too, prunes (with a warning) anything that is neither, and splits the transformer keys into a LoRA part and a norm part:

```diff
@@ -1787,14 +1787,41 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )

-        is_correct_format = all("lora" in key for key in state_dict.keys())
-        if not is_correct_format:
+        has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+        # Flux Control LoRAs also have norm keys
+        supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
+        has_norm_keys = any(norm_key in key for key in state_dict.keys() for norm_key in supported_norm_keys)
+
+        if not (has_lora_keys or has_norm_keys):
             raise ValueError("Invalid LoRA checkpoint.")

-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
+        def prune_state_dict_(state_dict):
+            pruned_keys = []
+            for key in list(state_dict.keys()):
+                is_lora_key_present = "lora" in key
+                is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys)
+                if not is_lora_key_present and not is_norm_key_present:
+                    state_dict.pop(key)
+                    pruned_keys.append(key)
+            return pruned_keys
+
+        pruned_keys = prune_state_dict_(state_dict)
+        if len(pruned_keys) > 0:
+            logger.warning(
+                f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}"
+            )
+
+        transformer_lora_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k}
+        transformer_norm_state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
+        }
+
+        if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
-                state_dict,
+                transformer_lora_state_dict,
                 network_alphas=network_alphas,
                 transformer=getattr(self, self.transformer_name)
                 if not hasattr(self, "transformer")
```
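To make the new dispatch concrete, here is a minimal standalone sketch of the same key partitioning, run against a toy state dict (the key names are hypothetical stand-ins for what a Flux Control LoRA checkpoint contains):

```python
# Hypothetical keys; a real Flux Control LoRA checkpoint has many more entries.
supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]

state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": 0,
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": 0,
    "transformer.transformer_blocks.0.attn.norm_q.weight": 0,
    "unet.mid_block.resnets.0.conv1.weight": 0,  # neither LoRA nor norm -> pruned
}

lora_sd = {k: v for k, v in state_dict.items() if "transformer." in k and "lora" in k}
norm_sd = {
    k: v
    for k, v in state_dict.items()
    if "transformer." in k and any(nk in k for nk in supported_norm_keys)
}
pruned = [k for k in state_dict if "lora" not in k and not any(nk in k for nk in supported_norm_keys)]

print(len(lora_sd), len(norm_sd), pruned)
# 2 1 ['unet.mid_block.resnets.0.conv1.weight']
```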
The norm part is then routed to a new loader, mirroring how the LoRA part is handled:

```diff
@@ -1804,6 +1831,14 @@ def load_lora_weights(
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )

+        if len(transformer_norm_state_dict) > 0:
+            self.load_norm_into_transformer(
+                transformer_norm_state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+            )
+
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
```
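With both branches in place, a Flux Control LoRA loads through the usual one-liner. A hedged usage sketch (the LoRA repo id is illustrative; substitute whichever repo actually hosts the Canny or Depth control LoRA):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# One call now routes "lora" keys through load_lora_into_transformer and
# "norm_*" keys through load_norm_into_transformer.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")  # illustrative repo id
```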
The new class method is introduced next to `load_lora_into_transformer`; the rendered diff cuts off after the first lines of the signature:

```diff
@@ -1860,6 +1895,15 @@ def load_lora_into_transformer(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )

+    @classmethod
+    def load_norm_into_transformer(
+        cls,
+        state_dict,
```

A review suggestion on this signature renames the parameter for clarity:

```diff
-        state_dict,
+        norm_state_dict,
```
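The method body is not shown in the rendered diff. For orientation only, here is a hedged sketch of the general technique of copying norm weights into a transformer module; this is an assumption-laden illustration, not the PR's actual implementation:

```python
import torch

def load_norm_into_transformer_sketch(norm_state_dict: dict, transformer: torch.nn.Module):
    # Drop the "transformer." prefix so keys match the module's own parameter names.
    renamed = {k.removeprefix("transformer."): v for k, v in norm_state_dict.items()}
    # strict=False: only norm keys are present; all other parameters stay untouched.
    missing, unexpected = transformer.load_state_dict(renamed, strict=False)
    if unexpected:
        raise ValueError(f"Unsupported keys in norm state dict: {unexpected}")
```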
In `load_lora_adapter`, rank detection now skips 1-D tensors, since LoRA bias entries carry no rank information:

```diff
@@ -216,7 +216,9 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         rank = {}
         for key, val in state_dict.items():
-            if "lora_B" in key:
+            # Cannot figure out the rank from LoRA layers that don't have at least 2 dimensions.
+            # Bias layers in LoRA only have a single dimension.
+            if "lora_B" in key and val.ndim > 1:
                 rank[key] = val.shape[1]

         if network_alphas is not None and len(network_alphas) >= 1:
```
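Why `val.ndim > 1`: for a LoRA pair, `lora_B.weight` has shape `(out_features, rank)`, so the rank is its second dimension, while a `lora_B.bias` entry is 1-D and has no second dimension to read. A standalone sketch with made-up shapes:

```python
import torch

state_dict = {
    "foo.lora_A.weight": torch.zeros(4, 64),   # (rank, in_features)
    "foo.lora_B.weight": torch.zeros(128, 4),  # (out_features, rank)
    "foo.lora_B.bias": torch.zeros(128),       # 1-D bias: no rank information
}

rank = {}
for key, val in state_dict.items():
    if "lora_B" in key and val.ndim > 1:
        rank[key] = val.shape[1]

print(rank)  # {'foo.lora_B.weight': 4}
```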
In the Flux pipeline, `PipelineImageInput` is pulled in for the new `control_image` argument:

```diff
@@ -19,7 +19,7 @@
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
```
A `prepare_image` helper is copied in from the SD3 ControlNet pipeline to preprocess and batch the conditioning image:

```diff
@@ -529,6 +529,41 @@ def prepare_latents(
         return latents, latent_image_ids

+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
```
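The only subtle part of `prepare_image` is the batching rule: a single conditioning image is tiled to the full batch, while a per-prompt batch of images is repeated once per generated image. A standalone illustration with a dummy tensor in place of a PIL image:

```python
import torch

image = torch.rand(1, 3, 64, 64)          # one conditioning image
batch_size, num_images_per_prompt = 4, 2  # hypothetical call values

# Single image -> tile to the full batch; otherwise repeat per generated image.
repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
print(image.shape)  # torch.Size([4, 3, 64, 64])
```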
`__call__` gains two new arguments, `control_image` and `control_latents`:

```diff
@@ -556,9 +591,11 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
```
The docstring is extended accordingly:

```diff
@@ -595,6 +632,14 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
```
```diff
@@ -667,6 +712,7 @@ def __call__(
         device = self._execution_device

         # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
```
Latent preparation now halves the channel count when a control image is supplied (the control latents will occupy the other half), then encodes and packs the control image:

```diff
@@ -686,7 +732,35 @@ def __call__(
         )

         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
+
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
```
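The `// 4` vs `// 8` split follows from Flux's 2x2 latent packing: the transformer's `in_channels` counts packed channels (4 per latent channel), and concatenating control latents along the packed channel axis doubles that count. A hedged, self-contained sketch of the shape arithmetic (latent sizes are illustrative):

```python
import torch

def pack_latents(latents):
    # Flux-style 2x2 patchification: (B, C, H, W) -> (B, H//2 * W//2, C * 4)
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

latents = pack_latents(torch.rand(1, 16, 64, 64))  # 16 VAE latent channels (illustrative)
control = pack_latents(torch.rand(1, 16, 64, 64))
model_input = torch.cat([latents, control], dim=2)  # concat along the packed channel dim
print(latents.shape, model_input.shape)
# torch.Size([1, 1024, 64]) torch.Size([1, 1024, 128])
# The transformer sees 128 packed channels, and 128 // 8 == 16 latent channels.
```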
In the denoising loop, the control latents are concatenated onto the model input at every step; the latents that get stepped by the scheduler remain the unconcatenated ones:

```diff
@@ -732,11 +806,16 @@ def __call__(
                 if self.interrupt:
                     continue

+                if control_latents is not None:
+                    latent_model_input = torch.cat([latents, control_latents], dim=2)
+                else:
+                    latent_model_input = latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
```
```diff
@@ -774,7 +853,6 @@ def __call__(
         if output_type == "latent":
             image = latents
-
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
```
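End to end, the new arguments are exercised like this. A hedged sketch (model and LoRA repo ids, the conditioning image URL, and the guidance value are illustrative; the loaded weights must actually expect `in_channels // 8` latent channels for the shapes to line up):

```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")  # illustrative repo id

control_image = load_image("https://example.com/canny_edges.png")  # illustrative URL
image = pipe(
    prompt="a robot made of exotic candies",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,  # illustrative value
).images[0]
```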
Review comment: For supporting the additional norm layers. Also, FYI: the norm layers from the LoRA are numerically identical to those in Flux1-Canny-Dev and Flux1-Depth-Dev, but different from those in Flux1-Dev (the model the LoRA is intended for), so we cannot do without this change.
Reply: Exactly. Thanks for confirming!