From 070cc7cfb296af5b8567286d25705aff6b08e8c9 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Tue, 10 Sep 2024 23:09:53 +0530 Subject: [PATCH 01/11] Implemented FLUX controlnet support to Img2Img pipeline --- src/diffusers/pipelines/flux/__init__.py | 2 + ...pipeline_flux_controlnet_image_to_image.py | 368 ++++++++++++++++++ 2 files changed, 370 insertions(+) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index e43a7ab753cd..534b160e93af 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -37,6 +37,8 @@ from .pipeline_flux_controlnet import FluxControlNetPipeline from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline + from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline + from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py new file mode 100644 index 000000000000..1a07739afb21 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -0,0 +1,368 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers import FluxControlNetPipeline, AutoencoderKL, FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel +from diffusers.pipelines.flux import FluxPipelineOutput +from diffusers.utils import logging, randn_tensor +from diffusers.utils.import_utils import is_torch_xla_available + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +class FluxControlNetImg2ImgPipeline(FluxControlNetPipeline): + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[FluxControlNetModel, List[FluxControlNetModel], FluxMultiControlNetModel], + ): + super().__init__( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + controlnet=controlnet, + ) + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + ): + image = image.to(device=device, dtype=dtype) + init_latents = self._encode_vae_image(image, generator=generator) + init_latents = init_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = self._pack_latents(latents, batch_size * num_images_per_prompt, shape[1], shape[2], shape[3]) + latent_image_ids = self._prepare_latent_image_ids(batch_size * num_images_per_prompt, shape[2], shape[3], device, dtype) + + return latents, latent_image_ids + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 28, + guidance_scale: float = 7.0, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + max_sequence_length: int = 512, + ): + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=max_sequence_length, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. Encode input prompt + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Preprocess image + height, width = self.image_processor.get_image_dimensions(image) + image = self.image_processor.preprocess(image) + + # 5. Prepare control image + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + mu = self.calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, latent_image_ids = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.transformer.config.guidance_embeds else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # ControlNet(s) inference + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latent_model_input, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=controlnet_conditioning_scale, + timestep=timestep / 1000, + guidance=torch.tensor([guidance_scale], device=device).expand(latents.shape[0]), + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=cross_attention_kwargs, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=torch.tensor([guidance_scale], device=device).expand(latents.shape[0]), + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # Perform guidance + if self.transformer.config.guidance_embeds: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file From b7e7da3260ef84ec5dec910d306112820082d3f0 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Tue, 10 Sep 2024 23:12:21 +0530 Subject: [PATCH 02/11] Init remove inpainting --- src/diffusers/pipelines/flux/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 534b160e93af..e6d681145215 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -38,7 +38,6 @@ from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline - from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline else: import sys From f80d17c682c564c4073e6fb92d088165c8402263 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Wed, 11 Sep 2024 22:32:15 +0530 Subject: [PATCH 03/11] Flux controlnet img2img and inpaint pipeline --- src/diffusers/__init__.py | 4 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/flux/__init__.py | 3 + ...pipeline_flux_controlnet_image_to_image.py | 698 +++++++++-- .../pipeline_flux_controlnet_inpainting.py | 1050 +++++++++++++++++ .../test_controlnet_flux_img2img.py | 317 +++++ .../test_controlnet_flux_inpaint.py | 291 +++++ 7 files changed, 2274 insertions(+), 93 deletions(-) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py create mode 100644 tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py create mode 100644 tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5b505b6a1f3a..1af5b441917c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -259,6 +259,8 @@ "CogVideoXVideoToVideoPipeline", "CycleDiffusionPipeline", "FluxControlNetPipeline", + "FluxControlNetInpaintPipeline", + "FluxControlNetImg2ImgPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", @@ -707,6 +709,8 @@ CogVideoXVideoToVideoPipeline, CycleDiffusionPipeline, FluxControlNetPipeline, + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e4d37a905b86..56886c031569 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -127,6 +127,8 @@ ] _import_structure["flux"] = [ "FluxControlNetPipeline", + "FluxControlNetImg2ImgPipeline", + "FluxControlNetInpaintPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", @@ -501,7 +503,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline + from .flux import FluxControlNetPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index e6d681145215..77e66bf144ba 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -26,6 +26,8 @@ _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] + _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] + _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -38,6 +40,7 @@ from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline + from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 1a07739afb21..b702e9d19dde 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -1,16 +1,35 @@ import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast - -from diffusers import FluxControlNetPipeline, AutoencoderKL, FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler -from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel -from diffusers.pipelines.flux import FluxPipelineOutput -from diffusers.utils import logging, randn_tensor -from diffusers.utils.import_utils import is_torch_xla_available +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.models.attention_processor import AttnProcessor2_0 + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -19,8 +38,61 @@ else: XLA_AVAILABLE = False -logger = logging.get_logger(__name__) - +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-controlnet-canny", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> init_image = load_image("https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg") + + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.6, + ... strength=0.7, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_img2img.png") + ``` +""" + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -80,7 +152,12 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps -class FluxControlNetImg2ImgPipeline(FluxControlNetPipeline): +class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, @@ -90,77 +167,467 @@ def __init__( text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, - controlnet: Union[FluxControlNetModel, List[FluxControlNetModel], FluxMultiControlNetModel], + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], ): - super().__init__( - scheduler=scheduler, + super().__init__() + + self.register_modules( vae=vae, text_encoder=text_encoder, - tokenizer=tokenizer, text_encoder_2=text_encoder_2, + tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, + scheduler=scheduler, controlnet=controlnet, ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor return image_latents + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + init_timestep = min(num_inference_steps * strength, num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) + t_start = int(max(num_inference_steps - init_timestep, 0)) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + height, + width, + control_image, + control_mode, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + pooled_prompt_embeds=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Check `image` + is_compiled = hasattr(torch.functional.F, "scaled_dot_product_attention") and isinstance( + self.transformer.modulators.attn1.processor, + AttnProcessor2_0, + ) + if ( + isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray)) + and isinstance(control_image, (torch.Tensor, PIL.Image.Image, np.ndarray)) + and not is_compiled + ): + raise TypeError( + f"image and control_image should be passed as separate arguments. But got {type(image)} and {type(control_image)}." + ) + + # Check `control_image` + if isinstance(self.controlnet, FluxControlNetModel): + self.check_image(control_image, prompt, prompt_embeds) + elif isinstance(self.controlnet, FluxMultiControlNetModel): + if not isinstance(control_image, list): + raise TypeError("For multiple controlnets: `control_image` must be type `list`") + if len(control_image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `control_image` must have the same length as the number of controlnets." + ) + for image_ in control_image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `control_mode` + if control_mode is not None and isinstance(self.controlnet, FluxMultiControlNetModel): + if not isinstance(control_mode, list): + raise ValueError("You have multiple ControlNets, but only provided one control mode.") + if len(control_mode) != len(self.controlnet.nets): + raise ValueError("Number of control modes does not match the number of ControlNets.") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + def prepare_latents( self, image, timestep, batch_size, - num_images_per_prompt, + num_channels_latents, + height, + width, dtype, device, generator, + latents=None, ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + image = image.to(device=device, dtype=dtype) - init_latents = self._encode_vae_image(image, generator=generator) - init_latents = init_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) - shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) - # get latents - latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = self._pack_latents(latents, batch_size * num_images_per_prompt, shape[1], shape[2], shape[3]) - latent_image_ids = self._prepare_latent_image_ids(batch_size * num_images_per_prompt, shape[2], shape[3], device, dtype) + image_batch_size = image.shape[0] - return latents, latent_image_ids + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, - strength: float = 0.8, + strength: float = 0.6, num_inference_steps: int = 28, + timesteps: List[int] = None, guidance_scale: float = 7.0, control_mode: Optional[Union[int, List[int]]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, @@ -171,25 +638,34 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - cross_attention_kwargs: Optional[Dict[str, Any]] = None, max_sequence_length: int = 512, ): - # 1. Check inputs + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + self.check_inputs( prompt, prompt_2, + image, strength, height, width, + control_image, + control_mode, + callback_on_step_end_tensor_inputs, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=None, max_sequence_length=max_sequence_length, ) - # 2. Define call parameters + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -200,36 +676,101 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype - # 3. Encode input prompt - lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) - # 4. Preprocess image - height, width = self.image_processor.get_image_dimensions(image) - image = self.image_processor.preprocess(image) + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) - # 5. Prepare control image - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=dtype, - ) + num_channels_latents = self.transformer.config.in_channels // 4 + + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + height, width = control_image_.shape[-2:] + + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) - # 6. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) - mu = self.calculate_shift( + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, @@ -246,18 +787,10 @@ def __call__( ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # 7. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( - image, + init_image, latent_timestep, batch_size * num_images_per_prompt, num_channels_latents, @@ -272,66 +805,50 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None + guidance = torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - - timestep = t.expand(latents.shape[0]).to(latents.dtype) - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.transformer.config.guidance_embeds else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + timestep = t.expand(latents.shape[0]).to(latents.dtype) - # ControlNet(s) inference controlnet_block_samples, controlnet_single_block_samples = self.controlnet( - hidden_states=latent_model_input, + hidden_states=latents, controlnet_cond=control_image, controlnet_mode=control_mode, conditioning_scale=controlnet_conditioning_scale, timestep=timestep / 1000, - guidance=torch.tensor([guidance_scale], device=device).expand(latents.shape[0]), + guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - joint_attention_kwargs=cross_attention_kwargs, + joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, ) - # Predict the noise residual noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents, timestep=timestep / 1000, - guidance=torch.tensor([guidance_scale], device=device).expand(latents.shape[0]), + guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, txt_ids=text_ids, img_ids=latent_image_ids, - joint_attention_kwargs=cross_attention_kwargs, + joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] - # Perform guidance - if self.transformer.config.guidance_embeds: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # Compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: @@ -343,7 +860,6 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -352,14 +868,12 @@ def __call__( if output_type == "latent": image = latents - else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) - # Offload all models self.maybe_free_model_hooks() if not return_dict: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py new file mode 100644 index 000000000000..e72c1673712e --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -0,0 +1,1050 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-controlnet-canny-inpaint", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> control_image = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg") + >>> init_image = load_image("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png") + >>> mask_image = load_image("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png") + + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.6, + ... strength=0.7, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_inpaint.png") + ``` +""" + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + ): + super().__init__() + + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + controlnet=controlnet, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + control_image, + height, + width, + strength, + output_type, + control_mode, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + + if isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray)): + raise ValueError( + f"image should be passed as separate argument. But got {type(image)}." + ) + + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined." + ) + + if isinstance(self.controlnet, FluxControlNetModel): + self.check_image(control_image, prompt, prompt_embeds) + elif isinstance(self.controlnet, FluxMultiControlNetModel): + if not isinstance(control_image, list): + raise TypeError("For multiple controlnets: `control_image` must be type `list`") + if len(control_image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `control_image` must have the same length as the number of controlnets." + ) + for image_ in control_image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + if control_mode is not None and isinstance(self.controlnet, FluxMultiControlNetModel): + if not isinstance(control_mode, list): + raise ValueError("You have multiple ControlNets, but only provided one control mode.") + if len(control_mode) != len(self.controlnet.nets): + raise ValueError("Number of control modes does not match the number of ControlNets.") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + padding_mask_crop: Optional[int] = None, + timesteps: List[int] = None, + num_inference_steps: int = 28, + guidance_scale: float = 7.0, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + control_image, + height, + width, + strength, + output_type, + control_mode, + callback_on_step_end_tensor_inputs, + prompt_embeds, + pooled_prompt_embeds, + padding_mask_crop, + max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + height, width = control_image.shape[-2:] + + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + ) + height, width = control_image_.shape[-2:] + + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + # 5.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=controlnet_conditioning_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py new file mode 100644 index 000000000000..940f74b807a4 --- /dev/null +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -0,0 +1,317 @@ +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxControlNetImg2ImgPipeline, + FluxTransformer2DModel, +) +from diffusers.models import FluxControlNetModel +from diffusers.utils import load_image +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_gpu, + slow, + torch_device, +) +from diffusers.utils.torch_utils import randn_tensor + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlNetImg2ImgPipeline + + params = frozenset( + [ + "prompt", + "image", + "control_image", + "height", + "width", + "strength", + "guidance_scale", + "num_inference_steps", + "prompt_embeds", + "pooled_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "image", "control_image"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + + torch.manual_seed(0) + controlnet = FluxControlNetModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + "controlnet": controlnet, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = randn_tensor( + (1, 3, 32, 32), + generator=generator, + device=torch.device(device), + dtype=torch.float32, + ) + + control_image = randn_tensor( + (1, 3, 32, 32), + generator=generator, + device=torch.device(device), + dtype=torch.float32, + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "control_image": control_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.5, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + "controlnet_conditioning_scale": 0.5, + } + + return inputs + + def test_controlnet_img2img_flux(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetImg2ImgPipeline(**components) + flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32) + flux_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = flux_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + + expected_slice = np.array( + [0.5182, 0.4976, 0.4718, 0.5249, 0.5039, 0.4751, 0.5168, 0.4980, 0.4738] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, f"Expected: {expected_slice}, got: {image_slice.flatten()}" + + def test_attention_slicing_forward_pass(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetImg2ImgPipeline(**components) + flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32) + flux_pipe.set_progress_bar_config(disable=None) + + flux_pipe.enable_attention_slicing() + inputs = self.get_dummy_inputs(torch_device) + output_sliced = flux_pipe(**inputs) + image_sliced = output_sliced.images + + flux_pipe.disable_attention_slicing() + inputs = self.get_dummy_inputs(torch_device) + output = flux_pipe(**inputs) + image = output.images + + assert np.abs(image_sliced.flatten() - image.flatten()).max() < 1e-3 + + def test_inference_batch_single_identical(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetImg2ImgPipeline(**components) + flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32) + flux_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + # make batch size 1 + inputs["prompt"] = [inputs["prompt"]] + inputs["image"] = inputs["image"][:1] + inputs["control_image"] = inputs["control_image"][:1] + + output = flux_pipe(**inputs) + image = output.images + + inputs["prompt"] = inputs["prompt"] * 2 + inputs["image"] = torch.cat([inputs["image"], inputs["image"]]) + inputs["control_image"] = torch.cat([inputs["control_image"], inputs["control_image"]]) + + output_batch = flux_pipe(**inputs) + image_batch = output_batch.images + + assert np.abs(image_batch[0].flatten() - image[0].flatten()).max() < 1e-3 + assert np.abs(image_batch[1].flatten() - image[0].flatten()).max() < 1e-3 + + +@slow +@require_torch_gpu +class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = FluxControlNetModel.from_pretrained( + "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ) + pipe = FluxControlNetImg2ImgPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "A girl in city, 25 years old, cool, futuristic" + control_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ) + init_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/init_image.png" + ) + + output = pipe( + prompt, + image=init_image, + control_image=control_image, + controlnet_conditioning_scale=0.6, + strength=0.7, + num_inference_steps=3, + guidance_scale=3.5, + output_type="np", + generator=generator, + ) + + image = output.images[0] + + assert image.shape == (1024, 1024, 3) + + image_slice = image[-3:, -3:, -1].flatten() + expected_slice = np.array([0.3242, 0.3320, 0.3359, 0.3281, 0.3398, 0.3359, 0.3086, 0.3203, 0.3203]) + assert np.abs(image_slice - expected_slice).max() < 1e-2 + + def test_depth(self): + controlnet = FluxControlNetModel.from_pretrained( + "InstantX/FLUX.1-dev-Controlnet-Depth-alpha", torch_dtype=torch.bfloat16 + ) + pipe = FluxControlNetImg2ImgPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "An astronaut riding a horse on mars" + control_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Depth-alpha/resolve/main/depth.png" + ) + init_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Depth-alpha/resolve/main/astronaut_riding_horse.png" + ) + + output = pipe( + prompt, + image=init_image, + control_image=control_image, + controlnet_conditioning_scale=0.6, + strength=0.7, + num_inference_steps=3, + guidance_scale=3.5, + output_type="np", + generator=generator, + ) + + image = output.images[0] + + assert image.shape == (1024, 1024, 3) + + image_slice = image[-3:, -3:, -1].flatten() + expected_slice = np.array([0.3164, 0.3242, 0.3281, 0.3203, 0.3320, 0.3281, 0.3008, 0.3125, 0.3125]) + assert np.abs(image_slice - expected_slice).max() < 1e-2 \ No newline at end of file diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py new file mode 100644 index 000000000000..9e7bbaec79ab --- /dev/null +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py @@ -0,0 +1,291 @@ +import gc +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxControlNetInpaintPipeline, + FluxTransformer2DModel, +) +from diffusers.models import FluxControlNetModel +from diffusers.utils import load_image +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + require_torch_gpu, + slow, + torch_device, +) +from diffusers.utils.torch_utils import randn_tensor + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class FluxControlNetInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlNetInpaintPipeline + + params = frozenset( + [ + "prompt", + "image", + "mask_image", + "control_image", + "height", + "width", + "strength", + "guidance_scale", + "prompt_embeds", + "pooled_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "image", "mask_image", "control_image"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + + torch.manual_seed(0) + controlnet = FluxControlNetModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + "controlnet": controlnet, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = torch.ones((1, 1, 32, 32)).to(device) + control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device, dtype=torch.float32) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "mask_image": mask_image, + "control_image": control_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.5, + "output_type": "np", + "controlnet_conditioning_scale": 0.5, + "strength": 0.8, + } + + return inputs + + def test_controlnet_inpaint_flux(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetInpaintPipeline(**components) + flux_pipe = flux_pipe.to(torch_device) + flux_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = flux_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + + expected_slice = np.array([0.5182, 0.4976, 0.4718, 0.5249, 0.5039, 0.4751, 0.5168, 0.4980, 0.4738]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_attention_slicing_forward_pass(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetInpaintPipeline(**components) + flux_pipe = flux_pipe.to(torch_device) + flux_pipe.set_progress_bar_config(disable=None) + + flux_pipe.enable_attention_slicing() + inputs = self.get_dummy_inputs(torch_device) + output_sliced = flux_pipe(**inputs) + image_sliced = output_sliced.images + + flux_pipe.disable_attention_slicing() + inputs = self.get_dummy_inputs(torch_device) + output = flux_pipe(**inputs) + image = output.images + + assert np.abs(image_sliced.flatten() - image.flatten()).max() < 1e-3 + + def test_inference_batch_single_identical(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetInpaintPipeline(**components) + flux_pipe = flux_pipe.to(torch_device) + flux_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + # make batch size 1 + inputs["prompt"] = [inputs["prompt"]] + inputs["image"] = inputs["image"][:1] + inputs["mask_image"] = inputs["mask_image"][:1] + inputs["control_image"] = inputs["control_image"][:1] + + output = flux_pipe(**inputs) + image = output.images + + inputs["prompt"] = inputs["prompt"] * 2 + inputs["image"] = torch.cat([inputs["image"], inputs["image"]]) + inputs["mask_image"] = torch.cat([inputs["mask_image"], inputs["mask_image"]]) + inputs["control_image"] = torch.cat([inputs["control_image"], inputs["control_image"]]) + + output_batch = flux_pipe(**inputs) + image_batch = output_batch.images + + assert np.abs(image_batch[0].flatten() - image[0].flatten()).max() < 1e-3 + assert np.abs(image_batch[1].flatten() - image[0].flatten()).max() < 1e-3 + + def test_flux_controlnet_inpaint_prompt_embeds(self): + components = self.get_dummy_components() + flux_pipe = FluxControlNetInpaintPipeline(**components) + flux_pipe = flux_pipe.to(torch_device) + flux_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = flux_pipe(**inputs) + image = output.images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = flux_pipe.encode_prompt(prompt, device=torch_device) + inputs["prompt_embeds"] = prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + output = flux_pipe(**inputs) + image_from_embeds = output.images[0] + + assert np.abs(image - image_from_embeds).max() < 1e-3 + + +@slow +@require_torch_gpu +class FluxControlNetInpaintPipelineSlowTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = FluxControlNetModel.from_pretrained( + "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ) + pipe = FluxControlNetInpaintPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "A girl in city, 25 years old, cool, futuristic" + control_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ) + init_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/init_image.png" + ) + mask_image = torch.ones((1, 1, init_image.height, init_image.width)) + + output = pipe( + prompt, + image=init_image, + mask_image=mask_image, + control_image=control_image, + controlnet_conditioning_scale=0.6, + strength=0.7, + num_inference_steps=3, + guidance_scale=3.5, + output_type="np", + generator=generator, + ) + + image = output.images[0] + + assert image.shape == (1024, 1024, 3) + + image_slice = image[-3:, -3:, -1].flatten() + expected_slice = np.array([0.3242, 0.3320, 0.3359, 0.3281, 0.3398, 0.3359, 0.3086, 0.3203, 0.3203]) + assert np.abs(image_slice - expected_slice).max() < 1e-2 \ No newline at end of file From d71ef15a804780db445b321936fad4c3193f154b Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Thu, 12 Sep 2024 15:18:25 +0530 Subject: [PATCH 04/11] style and quality enforced --- src/diffusers/__init__.py | 6 +- src/diffusers/pipelines/__init__.py | 9 ++- src/diffusers/pipelines/flux/__init__.py | 8 +-- ...pipeline_flux_controlnet_image_to_image.py | 33 ++++++----- .../pipeline_flux_controlnet_inpainting.py | 57 +++++++++++-------- .../test_controlnet_flux_img2img.py | 10 ++-- .../test_controlnet_flux_inpaint.py | 2 +- 7 files changed, 73 insertions(+), 52 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1af5b441917c..d876d5fc59a6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,9 +258,9 @@ "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", "CycleDiffusionPipeline", - "FluxControlNetPipeline", - "FluxControlNetInpaintPipeline", "FluxControlNetImg2ImgPipeline", + "FluxControlNetInpaintPipeline", + "FluxControlNetPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", @@ -708,9 +708,9 @@ CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CycleDiffusionPipeline, - FluxControlNetPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, + FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 56886c031569..b6d21cf19e99 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -503,7 +503,14 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxControlNetPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline + from .flux import ( + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, + FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, + FluxPipeline, + ) from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 77e66bf144ba..0ebf5ea6d78d 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -24,10 +24,10 @@ else: _import_structure["pipeline_flux"] = ["FluxPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] - _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] - _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -37,10 +37,10 @@ else: from .pipeline_flux import FluxPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline - from .pipeline_flux_img2img import FluxImg2ImgPipeline - from .pipeline_flux_inpaint import FluxInpaintPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_inpaint import FluxInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index b702e9d19dde..59d1b6dca70b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -1,8 +1,8 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import PIL import numpy as np +import PIL import torch from transformers import ( CLIPTextModel, @@ -31,6 +31,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import FluxPipelineOutput + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -47,11 +48,15 @@ >>> from diffusers import FluxControlNetImg2ImgPipeline >>> from diffusers.utils import load_image - >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-controlnet-canny", torch_dtype=torch.bfloat16) + >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-controlnet-canny", torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") - >>> init_image = load_image("https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg") + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + ... ) >>> prompt = "A girl in city, 25 years old, cool, futuristic" >>> image = pipe( @@ -67,6 +72,7 @@ ``` """ + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -80,6 +86,7 @@ def calculate_shift( mu = image_seq_len * m + b return mu + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -92,7 +99,8 @@ def retrieve_latents( return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") - + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -152,8 +160,8 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps -class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] @@ -387,7 +395,7 @@ def get_timesteps(self, num_inference_steps, strength, device): self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start - + def check_inputs( self, prompt, @@ -478,7 +486,7 @@ def check_inputs( raise ValueError("You have multiple ControlNets, but only provided one control mode.") if len(control_mode) != len(self.controlnet.nets): raise ValueError("Number of control modes does not match the number of ControlNets.") - + @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): @@ -493,7 +501,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): ) return latent_image_ids.to(device=device, dtype=dtype) - + @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -517,7 +525,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) return latents - + def prepare_latents( self, image, @@ -598,7 +606,7 @@ def prepare_image( image = torch.cat([image] * 2) return image - + @property def guidance_scale(self): return self._guidance_scale @@ -614,7 +622,7 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt - + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -643,7 +651,6 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, ): - height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -879,4 +886,4 @@ def __call__( if not return_dict: return (image,) - return FluxPipelineOutput(images=image) \ No newline at end of file + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index e72c1673712e..acf9d568ba07 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1,8 +1,8 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import PIL import numpy as np +import PIL import torch from transformers import ( CLIPTextModel, @@ -29,6 +29,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import FluxPipelineOutput + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -45,12 +46,20 @@ >>> from diffusers import FluxControlNetInpaintPipeline >>> from diffusers.utils import load_image - >>> pipe = FluxControlNetInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-controlnet-canny-inpaint", torch_dtype=torch.bfloat16) + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-controlnet-canny-inpaint", torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") - >>> control_image = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg") - >>> init_image = load_image("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png") - >>> mask_image = load_image("https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png") + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ... ) + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) >>> prompt = "A girl in city, 25 years old, cool, futuristic" >>> image = pipe( @@ -67,6 +76,7 @@ ``` """ + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -80,6 +90,7 @@ def calculate_shift( mu = image_seq_len * m + b return mu + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -92,7 +103,7 @@ def retrieve_latents( return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( @@ -153,6 +164,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): def __init__( self, @@ -241,7 +253,7 @@ def _get_t5_prompt_embeds( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds - + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds def _get_clip_prompt_embeds( self, @@ -283,7 +295,7 @@ def _get_clip_prompt_embeds( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds - + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt def encode_prompt( self, @@ -363,7 +375,7 @@ def encode_prompt( text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids - + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -390,7 +402,7 @@ def get_timesteps(self, num_inference_steps, strength, device): self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start - + def check_inputs( self, prompt, @@ -445,7 +457,7 @@ def check_inputs( raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - + if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( @@ -462,16 +474,11 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - if isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray)): - raise ValueError( - f"image should be passed as separate argument. But got {type(image)}." - ) + raise ValueError(f"image should be passed as separate argument. But got {type(image)}.") if mask_image is None: - raise ValueError( - "`mask_image` input cannot be undefined." - ) + raise ValueError("`mask_image` input cannot be undefined.") if isinstance(self.controlnet, FluxControlNetModel): self.check_image(control_image, prompt, prompt_embeds) @@ -516,7 +523,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) return latents - + @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents def _unpack_latents(latents, height, width, vae_scale_factor): @@ -531,7 +538,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) return latents - + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents def prepare_latents( self, @@ -583,7 +590,7 @@ def prepare_latents( image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, noise, image_latents, latent_image_ids - + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents def prepare_mask_latents( self, @@ -654,7 +661,7 @@ def prepare_mask_latents( ) return mask, masked_image_latents - + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image def prepare_image( self, @@ -689,7 +696,7 @@ def prepare_image( image = torch.cat([image] * 2) return image - + @property def guidance_scale(self): return self._guidance_scale @@ -705,7 +712,7 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt - + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1047,4 +1054,4 @@ def __call__( if not return_dict: return (image,) - return FluxPipelineOutput(images=image) \ No newline at end of file + return FluxPipelineOutput(images=image) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 940f74b807a4..ae7832ebaaa3 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -174,11 +174,11 @@ def test_controlnet_img2img_flux(self): assert image.shape == (1, 32, 32, 3) - expected_slice = np.array( - [0.5182, 0.4976, 0.4718, 0.5249, 0.5039, 0.4751, 0.5168, 0.4980, 0.4738] - ) + expected_slice = np.array([0.5182, 0.4976, 0.4718, 0.5249, 0.5039, 0.4751, 0.5168, 0.4980, 0.4738]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, f"Expected: {expected_slice}, got: {image_slice.flatten()}" + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" def test_attention_slicing_forward_pass(self): components = self.get_dummy_components() @@ -314,4 +314,4 @@ def test_depth(self): image_slice = image[-3:, -3:, -1].flatten() expected_slice = np.array([0.3164, 0.3242, 0.3281, 0.3203, 0.3320, 0.3281, 0.3008, 0.3125, 0.3125]) - assert np.abs(image_slice - expected_slice).max() < 1e-2 \ No newline at end of file + assert np.abs(image_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py index 9e7bbaec79ab..dff19963987d 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py @@ -288,4 +288,4 @@ def test_canny(self): image_slice = image[-3:, -3:, -1].flatten() expected_slice = np.array([0.3242, 0.3320, 0.3359, 0.3281, 0.3398, 0.3359, 0.3086, 0.3203, 0.3203]) - assert np.abs(image_slice - expected_slice).max() < 1e-2 \ No newline at end of file + assert np.abs(image_slice - expected_slice).max() < 1e-2 From d55b5cb12f1314bddf503bf339bda2d206054048 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Thu, 12 Sep 2024 15:39:30 +0530 Subject: [PATCH 05/11] doc string added for controlnet flux inpaint and img2img pipelines, and added copied to prepare_latents in img2img pipeline --- ...pipeline_flux_controlnet_image_to_image.py | 61 +++++++++++++++++ .../pipeline_flux_controlnet_inpainting.py | 67 +++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 59d1b6dca70b..d79ba1f201f4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -526,6 +526,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents def prepare_latents( self, image, @@ -651,6 +652,66 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to modify with the pipeline. + control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_mode (`int` or `List[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index acf9d568ba07..25b4968ffd93 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -744,6 +744,73 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to inpaint. + mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels + will be preserved. + masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-generated masked image latents. + control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. + padding_mask_crop (`int`, *optional*): + The size of the padding to use when cropping the mask. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_mode (`int` or `List[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor From a30ca655dc597aacc7f9e016458bae4355eb7a62 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Thu, 12 Sep 2024 21:14:18 +0530 Subject: [PATCH 06/11] added example usecases in inpaint and img2img pipeline --- ...pipeline_flux_controlnet_image_to_image.py | 77 ++++++---------- .../pipeline_flux_controlnet_inpainting.py | 90 ++++++++----------- 2 files changed, 65 insertions(+), 102 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index d79ba1f201f4..7e666a049f12 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -2,7 +2,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL import torch from transformers import ( CLIPTextModel, @@ -11,8 +10,6 @@ T5TokenizerFast, ) -from diffusers.models.attention_processor import AttnProcessor2_0 - from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...models.autoencoders import AutoencoderKL @@ -45,12 +42,21 @@ Examples: ```py >>> import torch - >>> from diffusers import FluxControlNetImg2ImgPipeline + >>> from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel >>> from diffusers.utils import load_image + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained( - ... "black-forest-labs/FLUX.1-controlnet-canny", torch_dtype=torch.bfloat16 + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 ... ) + + >>> pipe.text_encoder.to(torch.float16) + >>> pipe.controlnet.to(torch.float16) >>> pipe.to("cuda") >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") @@ -65,7 +71,7 @@ ... control_image=control_image, ... controlnet_conditioning_scale=0.6, ... strength=0.7, - ... num_inference_steps=28, + ... num_inference_steps=2, ... guidance_scale=3.5, ... ).images[0] >>> image.save("flux_controlnet_img2img.png") @@ -132,6 +138,8 @@ def retrieve_timesteps( Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. + + Examples: """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") @@ -400,12 +408,9 @@ def check_inputs( self, prompt, prompt_2, - image, strength, height, width, - control_image, - control_mode, callback_on_step_end_tensor_inputs, prompt_embeds=None, pooled_prompt_embeds=None, @@ -451,42 +456,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - # Check `image` - is_compiled = hasattr(torch.functional.F, "scaled_dot_product_attention") and isinstance( - self.transformer.modulators.attn1.processor, - AttnProcessor2_0, - ) - if ( - isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray)) - and isinstance(control_image, (torch.Tensor, PIL.Image.Image, np.ndarray)) - and not is_compiled - ): - raise TypeError( - f"image and control_image should be passed as separate arguments. But got {type(image)} and {type(control_image)}." - ) - - # Check `control_image` - if isinstance(self.controlnet, FluxControlNetModel): - self.check_image(control_image, prompt, prompt_embeds) - elif isinstance(self.controlnet, FluxMultiControlNetModel): - if not isinstance(control_image, list): - raise TypeError("For multiple controlnets: `control_image` must be type `list`") - if len(control_image) != len(self.controlnet.nets): - raise ValueError( - "For multiple controlnets: `control_image` must have the same length as the number of controlnets." - ) - for image_ in control_image: - self.check_image(image_, prompt, prompt_embeds) - else: - assert False - - # Check `control_mode` - if control_mode is not None and isinstance(self.controlnet, FluxMultiControlNetModel): - if not isinstance(control_mode, list): - raise ValueError("You have multiple ControlNets, but only provided one control mode.") - if len(control_mode) != len(self.controlnet.nets): - raise ValueError("Number of control modes does not match the number of ControlNets.") - @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): @@ -711,6 +680,8 @@ def __call__( [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + + Examples: """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -718,12 +689,9 @@ def __call__( self.check_inputs( prompt, prompt_2, - image, strength, height, width, - control_image, - control_mode, callback_on_step_end_tensor_inputs, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, @@ -873,9 +841,6 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - guidance = torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None - guidance = guidance.expand(latents.shape[0]) if guidance is not None else None - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -883,6 +848,11 @@ def __call__( timestep = t.expand(latents.shape[0]).to(latents.dtype) + guidance = ( + torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, @@ -898,6 +868,11 @@ def __call__( return_dict=False, ) + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 25b4968ffd93..922422ade3a0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -44,12 +44,22 @@ ```py >>> import torch >>> from diffusers import FluxControlNetInpaintPipeline + >>> from diffusers.models import FluxControlNetModel >>> from diffusers.utils import load_image + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( - ... "black-forest-labs/FLUX.1-controlnet-canny-inpaint", torch_dtype=torch.bfloat16 + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.bfloat16 ... ) - >>> pipe.to("cuda") + + >>> pipe.text_encoder.to(torch.float16) + >>> pipe.controlnet.to(torch.float16) + >>> pipe.to(device) >>> control_image = load_image( ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" @@ -61,13 +71,13 @@ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" ... ) - >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> prompt = "A girl holding a sign that says InstantX" >>> image = pipe( ... prompt, ... image=init_image, ... mask_image=mask_image, ... control_image=control_image, - ... controlnet_conditioning_scale=0.6, + ... controlnet_conditioning_scale=0.7, ... strength=0.7, ... num_inference_steps=28, ... guidance_scale=3.5, @@ -409,20 +419,18 @@ def check_inputs( prompt_2, image, mask_image, - control_image, + strength, height, width, - strength, output_type, - control_mode, - callback_on_step_end_tensor_inputs=None, prompt_embeds=None, pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, padding_mask_crop=None, max_sequence_length=None, ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -474,32 +482,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - if isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray)): - raise ValueError(f"image should be passed as separate argument. But got {type(image)}.") - - if mask_image is None: - raise ValueError("`mask_image` input cannot be undefined.") - - if isinstance(self.controlnet, FluxControlNetModel): - self.check_image(control_image, prompt, prompt_embeds) - elif isinstance(self.controlnet, FluxMultiControlNetModel): - if not isinstance(control_image, list): - raise TypeError("For multiple controlnets: `control_image` must be type `list`") - if len(control_image) != len(self.controlnet.nets): - raise ValueError( - "For multiple controlnets: `control_image` must have the same length as the number of controlnets." - ) - for image_ in control_image: - self.check_image(image_, prompt, prompt_embeds) - else: - assert False - - if control_mode is not None and isinstance(self.controlnet, FluxMultiControlNetModel): - if not isinstance(control_mode, list): - raise ValueError("You have multiple ControlNets, but only provided one control mode.") - if len(control_mode) != len(self.controlnet.nets): - raise ValueError("Number of control modes does not match the number of ControlNets.") - @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): @@ -810,6 +792,8 @@ def __call__( [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + + Examples: """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -820,17 +804,15 @@ def __call__( prompt_2, image, mask_image, - control_image, + strength, height, width, - strength, - output_type, - control_mode, - callback_on_step_end_tensor_inputs, - prompt_embeds, - pooled_prompt_embeds, - padding_mask_crop, - max_sequence_length, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale @@ -1026,13 +1008,6 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -1040,6 +1015,13 @@ def __call__( timestep = t.expand(latents.shape[0]).to(latents.dtype) + # handle guidance + if self.controlnet.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latents, controlnet_cond=control_image, @@ -1055,6 +1037,12 @@ def __call__( return_dict=False, ) + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, From b29d96ff42de8094dd152f21b76ec61288dd0558 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Fri, 13 Sep 2024 23:11:53 +0530 Subject: [PATCH 07/11] make fix copies added --- ...pipeline_flux_controlnet_image_to_image.py | 2 -- .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 7e666a049f12..3e81f95dedaf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -138,8 +138,6 @@ def retrieve_timesteps( Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. - - Examples: """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 732488721598..97da297fa02b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -317,6 +317,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From e543d3a552e4f61eba31cdf5417e6582a8f48861 Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Sat, 14 Sep 2024 16:06:36 +0530 Subject: [PATCH 08/11] docs added for img2img and inpaint, also added docs to pipelines --- docs/source/en/api/pipelines/flux.md | 133 +++++++++++++++--- ...pipeline_flux_controlnet_image_to_image.py | 30 +++- .../pipeline_flux_controlnet_inpainting.py | 34 ++++- 3 files changed, 176 insertions(+), 21 deletions(-) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index e006006a3393..1fd580c42e32 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -18,22 +18,22 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b -Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c). +Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c). Flux comes in two variants: -* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) -* Guidance-distilled (`black-forest-labs/FLUX.1-dev`) +- Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) +- Guidance-distilled (`black-forest-labs/FLUX.1-dev`) Both checkpoints have slightly difference usage which we detail below. ### Timestep-distilled -* `max_sequence_length` cannot be more than 256. -* `guidance_scale` needs to be 0. -* As this is a timestep-distilled model, it benefits from fewer sampling steps. +- `max_sequence_length` cannot be more than 256. +- `guidance_scale` needs to be 0. +- As this is a timestep-distilled model, it benefits from fewer sampling steps. ```python import torch @@ -56,8 +56,8 @@ out.save("image.png") ### Guidance-distilled -* The guidance-distilled variant takes about 50 sampling steps for good-quality generation. -* It doesn't have any limitations around the `max_sequence_length`. +- The guidance-distilled variant takes about 50 sampling steps for good-quality generation. +- It doesn't have any limitations around the `max_sequence_length`. ```python import torch @@ -78,9 +78,11 @@ out.save("image.png") ``` ## Running FP16 inference + Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details. FP16 inference code: + ```python import torch from diffusers import FluxPipeline @@ -160,18 +162,115 @@ image.save("flux-fp8-dev.png") ## FluxPipeline -[[autodoc]] FluxPipeline - - all - - __call__ +[[autodoc]] FluxPipeline - all - **call** ## FluxImg2ImgPipeline -[[autodoc]] FluxImg2ImgPipeline - - all - - __call__ +[[autodoc]] FluxImg2ImgPipeline - all - **call** ## FluxInpaintPipeline -[[autodoc]] FluxInpaintPipeline - - all - - __call__ +[[autodoc]] FluxInpaintPipeline - all - **call** + +## Flux ControlNet Inpaint Pipeline + +The Flux ControlNet Inpaint pipeline is designed for controllable image inpainting using the Flux architecture. + +### Overview + +This pipeline combines the power of Flux's transformer-based architecture with ControlNet conditioning and inpainting capabilities. It allows for guided image generation within specified masked areas of an input image. + +### Usage + +```python +import torch +from diffusers import FluxControlNetInpaintPipeline +from diffusers.models import FluxControlNetModel +from diffusers.utils import load_image + +device = "cuda" if torch.cuda.is_available() else "cpu" + +controlnet = FluxControlNetModel.from_pretrained( + "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 +) + +pipe = FluxControlNetInpaintPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.bfloat16 +) + +pipe.text_encoder.to(torch.float16) +pipe.controlnet.to(torch.float16) +pipe.to(device) + +control_image = load_image( + "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" +) +init_image = load_image( + "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +) +mask_image = load_image( + "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" +) + +prompt = "A girl holding a sign that says InstantX" +image = pipe( + prompt, + image=init_image, + mask_image=mask_image, + control_image=control_image, + controlnet_conditioning_scale=0.7, + strength=0.7, + num_inference_steps=28, + guidance_scale=3.5, +).images[0] + +image.save("flux_controlnet_inpaint.png") +``` + +## Flux ControlNet Image to Image Pipeline + +The Flux ControlNet Img2Img pipeline enables controllable image-to-image translation using the Flux architecture combined with ControlNet conditioning. + +### Overview + +This pipeline allows for the transformation of input images based on text prompts and ControlNet conditions. It leverages the Flux transformer-based architecture to generate high-quality output images while maintaining control over the generation process. + +### Usage + +```python +import torch +from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel +from diffusers.utils import load_image + +device = "cuda" if torch.cuda.is_available() else "cpu" + +controlnet = FluxControlNetModel.from_pretrained( + "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 +) + +pipe = FluxControlNetImg2ImgPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 +) + +pipe.text_encoder.to(torch.float16) +pipe.controlnet.to(torch.float16) +pipe.to(device) + +control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") +init_image = load_image( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +) + +prompt = "A girl in city, 25 years old, cool, futuristic" +image = pipe( + prompt, + image=init_image, + control_image=control_image, + controlnet_conditioning_scale=0.6, + strength=0.7, + num_inference_steps=2, + guidance_scale=3.5, +).images[0] + +image.save("flux_controlnet_img2img.png") +``` diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 3e81f95dedaf..72803b180c34 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -168,6 +168,32 @@ def retrieve_timesteps( class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for image-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] @@ -674,12 +700,12 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to 512): The maximum length of the sequence to be generated. + Examples: + Returns: [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. - - Examples: """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 922422ade3a0..d11dd3100d6b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -176,6 +176,36 @@ def retrieve_timesteps( class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, @@ -788,12 +818,12 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to 512): The maximum length of the sequence to be generated. + Examples: + Returns: [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. - - Examples: """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor From 245f97bc4f89060cb0f4b9bc66f89bd3680435ce Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Sun, 15 Sep 2024 22:17:51 +0530 Subject: [PATCH 09/11] Fix tests and minor bugs --- docs/source/en/api/pipelines/flux.md | 103 +------- .../pipeline_flux_controlnet_inpainting.py | 97 ++++---- .../test_controlnet_flux_inpaint.py | 233 +++++------------- .../flux/test_pipeline_flux_img2img.py | 180 ++++++++++++-- 4 files changed, 280 insertions(+), 333 deletions(-) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 1fd580c42e32..38e17b925fe7 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -172,105 +172,10 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxInpaintPipeline - all - **call** -## Flux ControlNet Inpaint Pipeline +## FluxControlNetInpaintPipeline -The Flux ControlNet Inpaint pipeline is designed for controllable image inpainting using the Flux architecture. +[[autodoc]] FluxControlNetInpaintPipeline - all - **call** -### Overview +## FluxControlNetImg2ImgPipeline -This pipeline combines the power of Flux's transformer-based architecture with ControlNet conditioning and inpainting capabilities. It allows for guided image generation within specified masked areas of an input image. - -### Usage - -```python -import torch -from diffusers import FluxControlNetInpaintPipeline -from diffusers.models import FluxControlNetModel -from diffusers.utils import load_image - -device = "cuda" if torch.cuda.is_available() else "cpu" - -controlnet = FluxControlNetModel.from_pretrained( - "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 -) - -pipe = FluxControlNetInpaintPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.bfloat16 -) - -pipe.text_encoder.to(torch.float16) -pipe.controlnet.to(torch.float16) -pipe.to(device) - -control_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" -) -init_image = load_image( - "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" -) -mask_image = load_image( - "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" -) - -prompt = "A girl holding a sign that says InstantX" -image = pipe( - prompt, - image=init_image, - mask_image=mask_image, - control_image=control_image, - controlnet_conditioning_scale=0.7, - strength=0.7, - num_inference_steps=28, - guidance_scale=3.5, -).images[0] - -image.save("flux_controlnet_inpaint.png") -``` - -## Flux ControlNet Image to Image Pipeline - -The Flux ControlNet Img2Img pipeline enables controllable image-to-image translation using the Flux architecture combined with ControlNet conditioning. - -### Overview - -This pipeline allows for the transformation of input images based on text prompts and ControlNet conditions. It leverages the Flux transformer-based architecture to generate high-quality output images while maintaining control over the generation process. - -### Usage - -```python -import torch -from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel -from diffusers.utils import load_image - -device = "cuda" if torch.cuda.is_available() else "cpu" - -controlnet = FluxControlNetModel.from_pretrained( - "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 -) - -pipe = FluxControlNetImg2ImgPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 -) - -pipe.text_encoder.to(torch.float16) -pipe.controlnet.to(torch.float16) -pipe.to(device) - -control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") -init_image = load_image( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" -) - -prompt = "A girl in city, 25 years old, cool, futuristic" -image = pipe( - prompt, - image=init_image, - control_image=control_image, - controlnet_conditioning_scale=0.6, - strength=0.7, - num_inference_steps=2, - guidance_scale=3.5, -).images[0] - -image.save("flux_controlnet_img2img.png") -``` +[[autodoc]] FluxControlNetImg2ImgPipeline - all - **call** diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d11dd3100d6b..d43acdf38ea5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -47,19 +47,13 @@ >>> from diffusers.models import FluxControlNetModel >>> from diffusers.utils import load_image - >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> controlnet = FluxControlNetModel.from_pretrained( - ... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16 ... ) - >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( - ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 ... ) - - >>> pipe.text_encoder.to(torch.float16) - >>> pipe.controlnet.to(torch.float16) - >>> pipe.to(device) + >>> pipe.to("cuda") >>> control_image = load_image( ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" @@ -232,7 +226,9 @@ def __init__( controlnet=controlnet, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, @@ -241,7 +237,6 @@ def __init__( do_binarize=True, do_convert_grayscale=True, ) - self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -828,6 +823,9 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + global_height = height + global_width = width + # 1. Check inputs self.check_inputs( prompt, @@ -849,21 +847,7 @@ def __call__( self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False - # 2. Preprocess mask and image - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - original_image = image - init_image = self.image_processor.preprocess( - image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image = init_image.to(dtype=torch.float32) - - # 3. Define call parameters + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -874,14 +858,11 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype + # 3. Encode input prompt lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, @@ -892,13 +873,29 @@ def __call__( lora_scale=lora_scale, ) - # 4. Prepare control image + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): control_image = self.prepare_image( image=control_image, - width=width, - height=height, + width=height, + height=width, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, @@ -969,9 +966,10 @@ def __call__( control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) control_mode = control_mode.reshape([-1, 1]) - # 5.Prepare timesteps + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -996,27 +994,25 @@ def __call__( ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # 6. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - num_channels_transformer = self.transformer.config.in_channels + # 7. Prepare latent variables latents, noise, image_latents, latent_image_ids = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, num_channels_latents, - height, - width, + global_height, + global_width, prompt_embeds.dtype, device, generator, latents, ) + # 8. Prepare mask latents mask_condition = self.mask_processor.preprocess( - mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords ) - if masked_image_latents is None: masked_image = init_image * (mask_condition < 0.5) else: @@ -1028,13 +1024,14 @@ def __call__( batch_size, num_channels_latents, num_images_per_prompt, - height, - width, + global_height, + global_width, prompt_embeds.dtype, device, generator, ) + # 9. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -1045,7 +1042,7 @@ def __call__( timestep = t.expand(latents.shape[0]).to(latents.dtype) - # handle guidance + # predict the noise residual if self.controlnet.config.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) @@ -1091,7 +1088,7 @@ def __call__( latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - # for 64 channel transformer only. + # For inpainting, we need to apply the mask and add the masked image latents init_latents_proper = image_latents init_mask = mask @@ -1108,6 +1105,7 @@ def __call__( # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) + # call the callback, if provided if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: @@ -1117,18 +1115,17 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() + # Post-processing if output_type == "latent": image = latents - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py index dff19963987d..d66eaaf6a76f 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py @@ -1,27 +1,29 @@ -import gc import random import unittest import numpy as np import torch -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +# torch_device, # {{ edit_1 }} Removed unused import +from transformers import ( + AutoTokenizer, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, +) from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlNetInpaintPipeline, + FluxControlNetModel, FluxTransformer2DModel, ) -from diffusers.models import FluxControlNetModel -from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, - require_torch_gpu, - slow, - torch_device, ) -from diffusers.utils.torch_utils import randn_tensor from ..test_pipelines_common import PipelineTesterMixin @@ -29,30 +31,32 @@ enable_full_determinism() -class FluxControlNetInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): +class FluxControlNetInpaintPipelineTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxControlNetInpaintPipeline - params = frozenset( [ "prompt", - "image", - "mask_image", - "control_image", "height", "width", - "strength", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds", + "image", + "mask_image", + "control_image", + "strength", + "num_inference_steps", + "controlnet_conditioning_scale", ] ) batch_params = frozenset(["prompt", "image", "mask_image", "control_image"]) + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, - in_channels=16, + in_channels=8, num_layers=1, num_single_layers=1, attention_head_dim=16, @@ -61,20 +65,6 @@ def get_dummy_components(self): pooled_projection_dim=32, axes_dims_rope=[4, 4, 8], ) - - torch.manual_seed(0) - controlnet = FluxControlNetModel( - patch_size=1, - in_channels=16, - num_layers=1, - num_single_layers=1, - attention_head_dim=16, - num_attention_heads=2, - joint_attention_dim=32, - pooled_projection_dim=32, - axes_dims_rope=[4, 4, 8], - ) - clip_text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, @@ -88,6 +78,7 @@ def get_dummy_components(self): hidden_act="gelu", projection_dim=32, ) + torch.manual_seed(0) text_encoder = CLIPTextModel(clip_text_encoder_config) @@ -95,7 +86,7 @@ def get_dummy_components(self): text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") torch.manual_seed(0) vae = AutoencoderKL( @@ -104,7 +95,7 @@ def get_dummy_components(self): out_channels=3, block_out_channels=(4,), layers_per_block=1, - latent_channels=4, + latent_channels=2, norm_num_groups=1, use_quant_conv=False, use_post_quant_conv=False, @@ -112,6 +103,19 @@ def get_dummy_components(self): scaling_factor=1.5035, ) + torch.manual_seed(0) + controlnet = FluxControlNetModel( + patch_size=1, + in_channels=8, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -129,11 +133,11 @@ def get_dummy_inputs(self, device, seed=0): if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: - generator = torch.Generator(device="cpu").manual_seed(seed) + generator = torch.Generator(device=device).manual_seed(seed) image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) mask_image = torch.ones((1, 1, 32, 32)).to(device) - control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device, dtype=torch.float32) + control_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) inputs = { "prompt": "A painting of a squirrel eating a burger", @@ -142,150 +146,49 @@ def get_dummy_inputs(self, device, seed=0): "control_image": control_image, "generator": generator, "num_inference_steps": 2, - "guidance_scale": 3.5, - "output_type": "np", - "controlnet_conditioning_scale": 0.5, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, "strength": 0.8, + "output_type": "np", } - return inputs - def test_controlnet_inpaint_flux(self): + def test_flux_controlnet_inpaint_with_num_images_per_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - flux_pipe = FluxControlNetInpaintPipeline(**components) - flux_pipe = flux_pipe.to(torch_device) - flux_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = flux_pipe(**inputs) - image = output.images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - - expected_slice = np.array([0.5182, 0.4976, 0.4718, 0.5249, 0.5039, 0.4751, 0.5168, 0.4980, 0.4738]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - def test_attention_slicing_forward_pass(self): - components = self.get_dummy_components() - flux_pipe = FluxControlNetInpaintPipeline(**components) - flux_pipe = flux_pipe.to(torch_device) - flux_pipe.set_progress_bar_config(disable=None) - - flux_pipe.enable_attention_slicing() - inputs = self.get_dummy_inputs(torch_device) - output_sliced = flux_pipe(**inputs) - image_sliced = output_sliced.images - - flux_pipe.disable_attention_slicing() - inputs = self.get_dummy_inputs(torch_device) - output = flux_pipe(**inputs) - image = output.images - - assert np.abs(image_sliced.flatten() - image.flatten()).max() < 1e-3 - - def test_inference_batch_single_identical(self): - components = self.get_dummy_components() - flux_pipe = FluxControlNetInpaintPipeline(**components) - flux_pipe = flux_pipe.to(torch_device) - flux_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - - # make batch size 1 - inputs["prompt"] = [inputs["prompt"]] - inputs["image"] = inputs["image"][:1] - inputs["mask_image"] = inputs["mask_image"][:1] - inputs["control_image"] = inputs["control_image"][:1] - - output = flux_pipe(**inputs) - image = output.images - - inputs["prompt"] = inputs["prompt"] * 2 - inputs["image"] = torch.cat([inputs["image"], inputs["image"]]) - inputs["mask_image"] = torch.cat([inputs["mask_image"], inputs["mask_image"]]) - inputs["control_image"] = torch.cat([inputs["control_image"], inputs["control_image"]]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) - output_batch = flux_pipe(**inputs) - image_batch = output_batch.images + inputs = self.get_dummy_inputs(device) + inputs["num_images_per_prompt"] = 2 + output = pipe(**inputs) + images = output.images - assert np.abs(image_batch[0].flatten() - image[0].flatten()).max() < 1e-3 - assert np.abs(image_batch[1].flatten() - image[0].flatten()).max() < 1e-3 + assert images.shape == (2, 32, 32, 3) - def test_flux_controlnet_inpaint_prompt_embeds(self): + def test_flux_controlnet_inpaint_with_controlnet_conditioning_scale(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - flux_pipe = FluxControlNetInpaintPipeline(**components) - flux_pipe = flux_pipe.to(torch_device) - flux_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(torch_device) - output = flux_pipe(**inputs) - image = output.images[0] - - inputs = self.get_dummy_inputs(torch_device) - prompt = inputs.pop("prompt") - - (prompt_embeds, pooled_prompt_embeds, text_ids) = flux_pipe.encode_prompt(prompt, device=torch_device) - inputs["prompt_embeds"] = prompt_embeds - inputs["pooled_prompt_embeds"] = pooled_prompt_embeds - output = flux_pipe(**inputs) - image_from_embeds = output.images[0] - - assert np.abs(image - image_from_embeds).max() < 1e-3 - - -@slow -@require_torch_gpu -class FluxControlNetInpaintPipelineSlowTests(unittest.TestCase): - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def test_canny(self): - controlnet = FluxControlNetModel.from_pretrained( - "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 - ) - pipe = FluxControlNetInpaintPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 - ) - pipe.enable_model_cpu_offload() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "A girl in city, 25 years old, cool, futuristic" - control_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" - ) - init_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/init_image.png" - ) - mask_image = torch.ones((1, 1, init_image.height, init_image.width)) + inputs = self.get_dummy_inputs(device) + output_default = pipe(**inputs) + image_default = output_default.images - output = pipe( - prompt, - image=init_image, - mask_image=mask_image, - control_image=control_image, - controlnet_conditioning_scale=0.6, - strength=0.7, - num_inference_steps=3, - guidance_scale=3.5, - output_type="np", - generator=generator, - ) + inputs["controlnet_conditioning_scale"] = 0.5 + output_scaled = pipe(**inputs) + image_scaled = output_scaled.images - image = output.images[0] + # Ensure that changing the controlnet_conditioning_scale produces a different output + assert not np.allclose(image_default, image_scaled, atol=0.01) - assert image.shape == (1024, 1024, 3) + def test_attention_slicing_forward_pass(self): + super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) - image_slice = image[-3:, -3:, -1].flatten() - expected_slice = np.array([0.3242, 0.3320, 0.3359, 0.3281, 0.3398, 0.3359, 0.3086, 0.3203, 0.3203]) - assert np.abs(image_slice - expected_slice).max() < 1e-2 + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index a038b1725812..9c0e948861f7 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -1,27 +1,49 @@ -import random +import gc import unittest import numpy as np import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxControlNetImg2ImgPipeline, + FluxControlNetModel, + FluxTransformer2DModel, +) from diffusers.utils.testing_utils import ( - enable_full_determinism, - floats_tensor, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin - +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) -enable_full_determinism() +class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlNetImg2ImgPipeline + params = frozenset( + [ + "prompt", + "image", + "control_image", + "height", + "width", + "strength", + "guidance_scale", + "controlnet_conditioning_scale", + "prompt_embeds", + "pooled_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "image", "control_image"]) -class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = FluxImg2ImgPipeline - params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) - batch_params = frozenset(["prompt"]) test_xformers_attention = False def get_dummy_components(self): @@ -75,6 +97,18 @@ def get_dummy_components(self): scaling_factor=1.5035, ) + torch.manual_seed(0) + controlnet = FluxControlNetModel( + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -85,30 +119,35 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "transformer": transformer, "vae": vae, + "controlnet": controlnet, } def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device="cpu").manual_seed(seed) + image = torch.randn(1, 3, 32, 32).to(device) + control_image = torch.randn(1, 3, 32, 32).to(device) + inputs = { "prompt": "A painting of a squirrel eating a burger", "image": image, + "control_image": control_image, "generator": generator, "num_inference_steps": 2, "guidance_scale": 5.0, - "height": 8, - "width": 8, - "max_sequence_length": 48, + "controlnet_conditioning_scale": 1.0, "strength": 0.8, + "height": 32, + "width": 32, + "max_sequence_length": 48, "output_type": "np", } return inputs - def test_flux_different_prompts(self): + def test_flux_controlnet_different_prompts(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -120,11 +159,9 @@ def test_flux_different_prompts(self): max_diff = np.abs(output_same_prompt - output_different_prompts).max() - # Outputs should be different here - # For some reasons, they don't show large differences assert max_diff > 1e-6 - def test_flux_prompt_embeds(self): + def test_flux_controlnet_prompt_embeds(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -147,3 +184,108 @@ def test_flux_prompt_embeds(self): max_diff = np.abs(output_with_prompt - output_with_embeds).max() assert max_diff < 1e-4 + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + + +@slow +@require_torch_gpu +class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase): + pipeline_class = FluxControlNetImg2ImgPipeline + repo_id = "black-forest-labs/FLUX.1-schnell" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = torch.randn(1, 3, 64, 64).to(device) + control_image = torch.randn(1, 3, 64, 64).to(device) + + return { + "prompt": "A photo of a cat", + "image": image, + "control_image": control_image, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "controlnet_conditioning_scale": 1.0, + "strength": 0.8, + "output_type": "np", + "generator": generator, + } + + @unittest.skip("We cannot run inference on this model with the current CI hardware") + def test_flux_controlnet_img2img_inference(self): + pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + [0.36132812, 0.30004883, 0.25830078], + [0.36669922, 0.31103516, 0.23754883], + [0.34814453, 0.29248047, 0.23583984], + [0.35791016, 0.30981445, 0.23999023], + [0.36328125, 0.31274414, 0.2607422], + [0.37304688, 0.32177734, 0.26171875], + [0.3671875, 0.31933594, 0.25756836], + [0.36035156, 0.31103516, 0.2578125], + [0.3857422, 0.33789062, 0.27563477], + [0.3701172, 0.31982422, 0.265625], + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4 From b466a7a449c4e63b887b69a9a28f04f3a5b8e57f Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Sun, 15 Sep 2024 23:37:22 +0530 Subject: [PATCH 10/11] fix flux docs --- docs/source/en/api/pipelines/flux.md | 39 +++++++++++++++++----------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 38e17b925fe7..255c69c854bc 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -18,22 +18,22 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b -Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c). +Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c). Flux comes in two variants: -- Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) -- Guidance-distilled (`black-forest-labs/FLUX.1-dev`) +* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) +* Guidance-distilled (`black-forest-labs/FLUX.1-dev`) Both checkpoints have slightly difference usage which we detail below. ### Timestep-distilled -- `max_sequence_length` cannot be more than 256. -- `guidance_scale` needs to be 0. -- As this is a timestep-distilled model, it benefits from fewer sampling steps. +* `max_sequence_length` cannot be more than 256. +* `guidance_scale` needs to be 0. +* As this is a timestep-distilled model, it benefits from fewer sampling steps. ```python import torch @@ -56,8 +56,8 @@ out.save("image.png") ### Guidance-distilled -- The guidance-distilled variant takes about 50 sampling steps for good-quality generation. -- It doesn't have any limitations around the `max_sequence_length`. +* The guidance-distilled variant takes about 50 sampling steps for good-quality generation. +* It doesn't have any limitations around the `max_sequence_length`. ```python import torch @@ -78,11 +78,9 @@ out.save("image.png") ``` ## Running FP16 inference - Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details. FP16 inference code: - ```python import torch from diffusers import FluxPipeline @@ -162,20 +160,31 @@ image.save("flux-fp8-dev.png") ## FluxPipeline -[[autodoc]] FluxPipeline - all - **call** +[[autodoc]] FluxPipeline + - all + - __call__ ## FluxImg2ImgPipeline -[[autodoc]] FluxImg2ImgPipeline - all - **call** +[[autodoc]] FluxImg2ImgPipeline + - all + - __call__ ## FluxInpaintPipeline -[[autodoc]] FluxInpaintPipeline - all - **call** +[[autodoc]] FluxInpaintPipeline + - all + - __call__ + ## FluxControlNetInpaintPipeline -[[autodoc]] FluxControlNetInpaintPipeline - all - **call** +[[autodoc]] FluxControlNetInpaintPipeline + - all + - __call__ ## FluxControlNetImg2ImgPipeline -[[autodoc]] FluxControlNetImg2ImgPipeline - all - **call** +[[autodoc]] FluxControlNetImg2ImgPipeline + - all + - __call__ From 833d3488e5114c2a38fde7b3e103473cf196c68d Mon Sep 17 00:00:00 2001 From: ighoshsubho Date: Sun, 15 Sep 2024 23:51:06 +0530 Subject: [PATCH 11/11] Flux tests fix --- .../test_controlnet_flux_img2img.py | 304 ++++++++---------- .../flux/test_pipeline_flux_img2img.py | 180 ++--------- 2 files changed, 158 insertions(+), 326 deletions(-) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index ae7832ebaaa3..9c0e948861f7 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -3,33 +3,31 @@ import numpy as np import torch -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlNetImg2ImgPipeline, + FluxControlNetModel, FluxTransformer2DModel, ) -from diffusers.models import FluxControlNetModel -from diffusers.utils import load_image from diffusers.utils.testing_utils import ( - enable_full_determinism, + numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device, ) -from diffusers.utils.torch_utils import randn_tensor -from ..test_pipelines_common import PipelineTesterMixin - - -enable_full_determinism() +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxControlNetImg2ImgPipeline - params = frozenset( [ "prompt", @@ -39,18 +37,20 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi "width", "strength", "guidance_scale", - "num_inference_steps", + "controlnet_conditioning_scale", "prompt_embeds", "pooled_prompt_embeds", ] ) batch_params = frozenset(["prompt", "image", "control_image"]) + test_xformers_attention = False + def get_dummy_components(self): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, - in_channels=16, + in_channels=4, num_layers=1, num_single_layers=1, attention_head_dim=16, @@ -59,20 +59,6 @@ def get_dummy_components(self): pooled_projection_dim=32, axes_dims_rope=[4, 4, 8], ) - - torch.manual_seed(0) - controlnet = FluxControlNetModel( - patch_size=1, - in_channels=16, - num_layers=1, - num_single_layers=1, - attention_head_dim=16, - num_attention_heads=2, - joint_attention_dim=32, - pooled_projection_dim=32, - axes_dims_rope=[4, 4, 8], - ) - clip_text_encoder_config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, @@ -86,6 +72,7 @@ def get_dummy_components(self): hidden_act="gelu", projection_dim=32, ) + torch.manual_seed(0) text_encoder = CLIPTextModel(clip_text_encoder_config) @@ -93,7 +80,7 @@ def get_dummy_components(self): text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") torch.manual_seed(0) vae = AutoencoderKL( @@ -102,7 +89,7 @@ def get_dummy_components(self): out_channels=3, block_out_channels=(4,), layers_per_block=1, - latent_channels=4, + latent_channels=1, norm_num_groups=1, use_quant_conv=False, use_post_quant_conv=False, @@ -110,6 +97,18 @@ def get_dummy_components(self): scaling_factor=1.5035, ) + torch.manual_seed(0) + controlnet = FluxControlNetModel( + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -129,19 +128,8 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device="cpu").manual_seed(seed) - image = randn_tensor( - (1, 3, 32, 32), - generator=generator, - device=torch.device(device), - dtype=torch.float32, - ) - - control_image = randn_tensor( - (1, 3, 32, 32), - generator=generator, - device=torch.device(device), - dtype=torch.float32, - ) + image = torch.randn(1, 3, 32, 32).to(device) + control_image = torch.randn(1, 3, 32, 32).to(device) inputs = { "prompt": "A painting of a squirrel eating a burger", @@ -149,85 +137,99 @@ def get_dummy_inputs(self, device, seed=0): "control_image": control_image, "generator": generator, "num_inference_steps": 2, - "guidance_scale": 3.5, - "height": 8, - "width": 8, - "max_sequence_length": 48, + "guidance_scale": 5.0, + "controlnet_conditioning_scale": 1.0, "strength": 0.8, + "height": 32, + "width": 32, + "max_sequence_length": 48, "output_type": "np", - "controlnet_conditioning_scale": 0.5, } - return inputs - def test_controlnet_img2img_flux(self): - components = self.get_dummy_components() - flux_pipe = FluxControlNetImg2ImgPipeline(**components) - flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32) - flux_pipe.set_progress_bar_config(disable=None) + def test_flux_controlnet_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) - output = flux_pipe(**inputs) - image = output.images + output_same_prompt = pipe(**inputs).images[0] - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] - expected_slice = np.array([0.5182, 0.4976, 0.4718, 0.5249, 0.5039, 0.4751, 0.5168, 0.4980, 0.4738]) + max_diff = np.abs(output_same_prompt - output_different_prompts).max() - assert ( - np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - ), f"Expected: {expected_slice}, got: {image_slice.flatten()}" + assert max_diff > 1e-6 - def test_attention_slicing_forward_pass(self): - components = self.get_dummy_components() - flux_pipe = FluxControlNetImg2ImgPipeline(**components) - flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32) - flux_pipe.set_progress_bar_config(disable=None) - - flux_pipe.enable_attention_slicing() + def test_flux_controlnet_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) - output_sliced = flux_pipe(**inputs) - image_sliced = output_sliced.images - flux_pipe.disable_attention_slicing() - inputs = self.get_dummy_inputs(torch_device) - output = flux_pipe(**inputs) - image = output.images - - assert np.abs(image_sliced.flatten() - image.flatten()).max() < 1e-3 - - def test_inference_batch_single_identical(self): - components = self.get_dummy_components() - flux_pipe = FluxControlNetImg2ImgPipeline(**components) - flux_pipe = flux_pipe.to(torch_device, dtype=torch.float32) - flux_pipe.set_progress_bar_config(disable=None) + output_with_prompt = pipe(**inputs).images[0] inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") - # make batch size 1 - inputs["prompt"] = [inputs["prompt"]] - inputs["image"] = inputs["image"][:1] - inputs["control_image"] = inputs["control_image"][:1] - - output = flux_pipe(**inputs) - image = output.images + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] - inputs["prompt"] = inputs["prompt"] * 2 - inputs["image"] = torch.cat([inputs["image"], inputs["image"]]) - inputs["control_image"] = torch.cat([inputs["control_image"], inputs["control_image"]]) + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 - output_batch = flux_pipe(**inputs) - image_batch = output_batch.images + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) - assert np.abs(image_batch[0].flatten() - image[0].flatten()).max() < 1e-3 - assert np.abs(image_batch[1].flatten() - image[0].flatten()).max() < 1e-3 + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." @slow @require_torch_gpu class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase): + pipeline_class = FluxControlNetImg2ImgPipeline + repo_id = "black-forest-labs/FLUX.1-schnell" + def setUp(self): super().setUp() gc.collect() @@ -238,80 +240,52 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def test_canny(self): - controlnet = FluxControlNetModel.from_pretrained( - "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 - ) - pipe = FluxControlNetImg2ImgPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 - ) - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "A girl in city, 25 years old, cool, futuristic" - control_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" - ) - init_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/init_image.png" - ) - - output = pipe( - prompt, - image=init_image, - control_image=control_image, - controlnet_conditioning_scale=0.6, - strength=0.7, - num_inference_steps=3, - guidance_scale=3.5, - output_type="np", - generator=generator, - ) - - image = output.images[0] + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) - assert image.shape == (1024, 1024, 3) + image = torch.randn(1, 3, 64, 64).to(device) + control_image = torch.randn(1, 3, 64, 64).to(device) - image_slice = image[-3:, -3:, -1].flatten() - expected_slice = np.array([0.3242, 0.3320, 0.3359, 0.3281, 0.3398, 0.3359, 0.3086, 0.3203, 0.3203]) - assert np.abs(image_slice - expected_slice).max() < 1e-2 + return { + "prompt": "A photo of a cat", + "image": image, + "control_image": control_image, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "controlnet_conditioning_scale": 1.0, + "strength": 0.8, + "output_type": "np", + "generator": generator, + } - def test_depth(self): - controlnet = FluxControlNetModel.from_pretrained( - "InstantX/FLUX.1-dev-Controlnet-Depth-alpha", torch_dtype=torch.bfloat16 - ) - pipe = FluxControlNetImg2ImgPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 - ) + @unittest.skip("We cannot run inference on this model with the current CI hardware") + def test_flux_controlnet_img2img_inference(self): + pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "An astronaut riding a horse on mars" - control_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Depth-alpha/resolve/main/depth.png" + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + [0.36132812, 0.30004883, 0.25830078], + [0.36669922, 0.31103516, 0.23754883], + [0.34814453, 0.29248047, 0.23583984], + [0.35791016, 0.30981445, 0.23999023], + [0.36328125, 0.31274414, 0.2607422], + [0.37304688, 0.32177734, 0.26171875], + [0.3671875, 0.31933594, 0.25756836], + [0.36035156, 0.31103516, 0.2578125], + [0.3857422, 0.33789062, 0.27563477], + [0.3701172, 0.31982422, 0.265625], + ], + dtype=np.float32, ) - init_image = load_image( - "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Depth-alpha/resolve/main/astronaut_riding_horse.png" - ) - - output = pipe( - prompt, - image=init_image, - control_image=control_image, - controlnet_conditioning_scale=0.6, - strength=0.7, - num_inference_steps=3, - guidance_scale=3.5, - output_type="np", - generator=generator, - ) - - image = output.images[0] - assert image.shape == (1024, 1024, 3) + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - image_slice = image[-3:, -3:, -1].flatten() - expected_slice = np.array([0.3164, 0.3242, 0.3281, 0.3203, 0.3320, 0.3281, 0.3008, 0.3125, 0.3125]) - assert np.abs(image_slice - expected_slice).max() < 1e-2 + assert max_diff < 1e-4 diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index 9c0e948861f7..a038b1725812 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -1,49 +1,27 @@ -import gc +import random import unittest import numpy as np import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel -from diffusers import ( - AutoencoderKL, - FlowMatchEulerDiscreteScheduler, - FluxControlNetImg2ImgPipeline, - FluxControlNetModel, - FluxTransformer2DModel, -) +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel from diffusers.utils.testing_utils import ( - numpy_cosine_similarity_distance, - require_torch_gpu, - slow, + enable_full_determinism, + floats_tensor, torch_device, ) -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin -class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = FluxControlNetImg2ImgPipeline - params = frozenset( - [ - "prompt", - "image", - "control_image", - "height", - "width", - "strength", - "guidance_scale", - "controlnet_conditioning_scale", - "prompt_embeds", - "pooled_prompt_embeds", - ] - ) - batch_params = frozenset(["prompt", "image", "control_image"]) +enable_full_determinism() + +class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxImg2ImgPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) test_xformers_attention = False def get_dummy_components(self): @@ -97,18 +75,6 @@ def get_dummy_components(self): scaling_factor=1.5035, ) - torch.manual_seed(0) - controlnet = FluxControlNetModel( - in_channels=4, - num_layers=1, - num_single_layers=1, - attention_head_dim=16, - num_attention_heads=2, - joint_attention_dim=32, - pooled_projection_dim=32, - axes_dims_rope=[4, 4, 8], - ) - scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -119,35 +85,30 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "transformer": transformer, "vae": vae, - "controlnet": controlnet, } def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device="cpu").manual_seed(seed) - image = torch.randn(1, 3, 32, 32).to(device) - control_image = torch.randn(1, 3, 32, 32).to(device) - inputs = { "prompt": "A painting of a squirrel eating a burger", "image": image, - "control_image": control_image, "generator": generator, "num_inference_steps": 2, "guidance_scale": 5.0, - "controlnet_conditioning_scale": 1.0, - "strength": 0.8, - "height": 32, - "width": 32, + "height": 8, + "width": 8, "max_sequence_length": 48, + "strength": 0.8, "output_type": "np", } return inputs - def test_flux_controlnet_different_prompts(self): + def test_flux_different_prompts(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -159,9 +120,11 @@ def test_flux_controlnet_different_prompts(self): max_diff = np.abs(output_same_prompt - output_different_prompts).max() + # Outputs should be different here + # For some reasons, they don't show large differences assert max_diff > 1e-6 - def test_flux_controlnet_prompt_embeds(self): + def test_flux_prompt_embeds(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) @@ -184,108 +147,3 @@ def test_flux_controlnet_prompt_embeds(self): max_diff = np.abs(output_with_prompt - output_with_embeds).max() assert max_diff < 1e-4 - - def test_fused_qkv_projections(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - original_image_slice = image[0, -3:, -3:, -1] - - pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice_fused = image[0, -3:, -3:, -1] - - pipe.transformer.unfuse_qkv_projections() - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice_disabled = image[0, -3:, -3:, -1] - - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." - - -@slow -@require_torch_gpu -class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase): - pipeline_class = FluxControlNetImg2ImgPipeline - repo_id = "black-forest-labs/FLUX.1-schnell" - - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def get_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - image = torch.randn(1, 3, 64, 64).to(device) - control_image = torch.randn(1, 3, 64, 64).to(device) - - return { - "prompt": "A photo of a cat", - "image": image, - "control_image": control_image, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "controlnet_conditioning_scale": 1.0, - "strength": 0.8, - "output_type": "np", - "generator": generator, - } - - @unittest.skip("We cannot run inference on this model with the current CI hardware") - def test_flux_controlnet_img2img_inference(self): - pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) - pipe.enable_model_cpu_offload() - - inputs = self.get_inputs(torch_device) - - image = pipe(**inputs).images[0] - image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - [0.36132812, 0.30004883, 0.25830078], - [0.36669922, 0.31103516, 0.23754883], - [0.34814453, 0.29248047, 0.23583984], - [0.35791016, 0.30981445, 0.23999023], - [0.36328125, 0.31274414, 0.2607422], - [0.37304688, 0.32177734, 0.26171875], - [0.3671875, 0.31933594, 0.25756836], - [0.36035156, 0.31103516, 0.2578125], - [0.3857422, 0.33789062, 0.27563477], - [0.3701172, 0.31982422, 0.265625], - ], - dtype=np.float32, - ) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - - assert max_diff < 1e-4