Feature flux controlnet img2img and inpaint pipeline #9408
Merged: yiyixuxu merged 16 commits into huggingface:main from ighoshsubho:feature/flux_controlnet_image_inpaint on Sep 17, 2024 (+2,634 −1).
Changes from 2 of 16 commits.

Commits (all by ighoshsubho):
- 070cc7c: Implemented FLUX controlnet support to Img2Img pipeline
- b7e7da3: Init remove inpainting
- f80d17c: Flux controlnet img2img and inpaint pipeline
- ba3f177: Merge branch 'main' into feature/flux_controlnet_image_inpaint
- 72ac92a: Merge branch 'huggingface:main' into feature/flux_controlnet_image_in…
- d71ef15: style and quality enforced
- d55b5cb: doc string added for controlnet flux inpaint and img2img pipelines, a…
- a30ca65: added example usecases in inpaint and img2img pipeline
- b29d96f: make fix copies added
- 21862e2: Merge branch 'main' into feature/flux_controlnet_image_inpaint
- e543d3a: docs added for img2img and inpaint, also added docs to pipelines
- 245f97b: Fix tests and minor bugs
- 30be607: Merge branch 'main' into feature/flux_controlnet_image_inpaint
- b466a7a: fix flux docs
- 833d348: Flux tests fix
- b385dcf: Merge branch 'main' into feature/flux_controlnet_image_inpaint
Changed file: src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py (new file, 368 additions, 0 deletions)

```python
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlNetPipeline, FluxTransformer2DModel
from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel
from diffusers.pipelines.flux import FluxPipelineOutput
from diffusers.utils import logging
from diffusers.utils.import_utils import is_torch_xla_available
from diffusers.utils.torch_utils import randn_tensor


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
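
# Illustrative usage of `retrieve_timesteps` (not part of this diff): with a FlowMatchEulerDiscreteScheduler,
# a custom sigma schedule and the Flux shift parameter `mu` are forwarded straight to `scheduler.set_timesteps`:
#
#     sigmas = np.linspace(1.0, 1 / 28, 28)
#     timesteps, num_steps = retrieve_timesteps(scheduler, device="cuda", sigmas=list(sigmas), mu=1.15)
#
# The `mu` value above is an arbitrary placeholder; the pipeline below derives it from the latent
# sequence length via `calculate_shift`.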


class FluxControlNetImg2ImgPipeline(FluxControlNetPipeline):
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
        controlnet: Union[FluxControlNetModel, List[FluxControlNetModel], FluxMultiControlNetModel],
    ):
        super().__init__(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            controlnet=controlnet,
        )

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        if isinstance(generator, list):
            image_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

        # normalize with the VAE's shift and scaling factors so the latents match the diffusion training range
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        return image_latents

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        # e.g. 28 steps at strength 0.8 -> skip the first 6 timesteps and denoise over the remaining 22
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        generator,
    ):
        image = image.to(device=device, dtype=dtype)
        init_latents = self._encode_vae_image(image, generator=generator)
        init_latents = init_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        latents = self.scheduler.add_noise(init_latents, noise, timestep)
        # pack the 2D latents into the patch-sequence layout expected by the Flux transformer
        latents = self._pack_latents(latents, batch_size * num_images_per_prompt, shape[1], shape[2], shape[3])
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size * num_images_per_prompt, shape[2], shape[3], device, dtype
        )

        return latents, latent_image_ids

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        control_image: Union[
            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
        ] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 0.8,
        num_inference_steps: int = 28,
        guidance_scale: float = 7.0,
        control_mode: Optional[Union[int, List[int]]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 512,
    ):
        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            strength,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=None,
            max_sequence_length=max_sequence_length,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        dtype = self.transformer.dtype

        # 3. Encode input prompt
        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Preprocess image
        height, width = self.image_processor.get_image_dimensions(image)
        image = self.image_processor.preprocess(image)

        # 5. Prepare control image
        control_image = self.prepare_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=dtype,
        )

        # 6. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
        # dynamic shift (mu) for the flow-matching schedule, derived from the latent sequence length
        mu = self.calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        # no custom timestep schedule is exposed here, so only the sigma schedule is forwarded
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 7. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4

        latents, latent_image_ids = self.prepare_latents(
            image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                # Expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.transformer.config.guidance_embeds else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # ControlNet(s) inference
                controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                    hidden_states=latent_model_input,
                    controlnet_cond=control_image,
                    controlnet_mode=control_mode,
                    conditioning_scale=controlnet_conditioning_scale,
                    timestep=timestep / 1000,
                    guidance=torch.tensor([guidance_scale], device=device).expand(latents.shape[0]),
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )

                # Predict the noise residual
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=torch.tensor([guidance_scale], device=device).expand(latents.shape[0]),
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # Perform guidance
                if self.transformer.config.guidance_embeds:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
```
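
For anyone who wants to try the branch, a minimal usage sketch of the new `FluxControlNetImg2ImgPipeline` could look like the following. It is illustrative only: the checkpoint names (`black-forest-labs/FLUX.1-dev`, `InstantX/FLUX.1-dev-Controlnet-Canny`), the placeholder image URLs, and the direct module import are assumptions made for the example, not part of this diff.

```python
import torch
from diffusers import FluxControlNetModel
from diffusers.utils import load_image

# Import directly from the module added in this PR (a top-level export may or may not exist at this commit).
from diffusers.pipelines.flux.pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline

# Illustrative checkpoints; any Flux base model plus a matching Flux ControlNet should work the same way.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder URLs: `image` is the picture to transform, `control_image` the conditioning signal (e.g. Canny edges).
init_image = load_image("https://example.com/init.png")
control_image = load_image("https://example.com/canny_edges.png")

result = pipe(
    prompt="a futuristic cityscape at dusk, ultra detailed",
    image=init_image,
    control_image=control_image,
    strength=0.8,  # how far to move away from the init image
    num_inference_steps=28,
    guidance_scale=7.0,
    controlnet_conditioning_scale=0.7,  # weight of the ControlNet residuals
).images[0]
result.save("flux_controlnet_img2img.png")
```

With the defaults above, the `strength`/`get_timesteps` interaction mirrors the other img2img pipelines: roughly the last 22 of the 28 flow-matching steps are run on the noised init-image latents.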