diff --git a/examples/community/README.md b/examples/community/README.md
index 8f4ab80d680b..e51124e75956 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 | Example | Description | Code Example | Colab | Author |
 |:----------------------------------------------------|:--------------------------------------------------------------------------------------------|:---------------------------------------|:--------------------------------------------------|--------------------------------:|
+|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of classifier-free guidance (CFG) for [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
 |Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
 | HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithful and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
 | Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
@@ -82,6 +83,36 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
 
 ## Example usages
 
+### Flux with CFG
+
+Learn more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since the distilled Flux models don't perform true classifier-free guidance, this implementation provides it, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
+
+Example usage:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="pipeline_flux_with_cfg"
+)
+pipeline.enable_model_cpu_offload()
+prompt = "a watercolor painting of a unicorn"
+negative_prompt = "pink"
+
+img = pipeline(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    true_cfg=1.5,
+    guidance_scale=3.5,
+    num_images_per_prompt=1,
+    generator=torch.manual_seed(0)
+).images[0]
+img.save("cfg_flux.png")
+```
+
 ### Differential Diffusion
 
 **Eran Levin, Ohad Fried**
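A note on the README example above: `guidance_scale` feeds Flux's built-in guidance embedding (visible in the `guidance_embeds` branch of the pipeline diff below), while `true_cfg` controls the classic classifier-free guidance mix this pipeline adds. The combination rule is the standard one; a minimal sketch with illustrative scalars standing in for real model outputs:

```py
# Illustrative scalars, not real noise predictions.
pos = 0.8   # prediction under the positive prompt
neg = 0.2   # prediction under the negative prompt
true_cfg = 1.5

# true_cfg = 1.0 returns the positive prediction unchanged;
# larger values push the result further away from the negative prediction.
combined = neg + true_cfg * (pos - neg)
print(combined)  # 0.2 + 1.5 * 0.6 = 1.1
```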
- """ device = device or self._execution_device - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it + # Set LoRA scale if applicable if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale - # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if do_true_cfg and negative_prompt is not None: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_batch_size = len(negative_prompt) + + if negative_batch_size != batch_size: + raise ValueError( + f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})" + ) + + # Concatenate prompts + prompts = prompt + negative_prompt + prompts_2 = ( + prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None + ) + else: + prompts = prompt + prompts_2 = prompt_2 if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + if prompts_2 is None: + prompts_2 = prompts - # We only use the pooled prompt output from the CLIPTextModel + # Get pooled prompt embeddings from CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, + prompt=prompts, device=device, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, + prompt=prompts_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) + if do_true_cfg and negative_prompt is not None: + # Split embeddings back into positive and negative parts + total_batch_size = batch_size * num_images_per_prompt + positive_indices = slice(0, total_batch_size) + negative_indices = slice(total_batch_size, 2 * total_batch_size) + + positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices] + negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices] + + positive_prompt_embeds = prompt_embeds[positive_indices] + negative_prompt_embeds = prompt_embeds[negative_indices] + + pooled_prompt_embeds = positive_pooled_prompt_embeds + prompt_embeds = positive_prompt_embeds + + # Unscale LoRA layers if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - return prompt_embeds, pooled_prompt_embeds, text_ids + if do_true_cfg and negative_prompt is not None: + return ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ) + else: + return prompt_embeds, pooled_prompt_embeds, text_ids, None, None def check_inputs( self, @@ -687,38 +711,33 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if 
@@ -687,38 +711,33 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )
 
-        # perform "real" CFG as suggested for distilled Flux models in https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         if do_true_cfg:
-            (
-                negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                negative_text_ids,
-            ) = self.encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                lora_scale=lora_scale,
-            )
+            # Concatenate embeddings
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -754,24 +773,26 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
+                # handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -783,18 +804,7 @@ def __call__(
                 )[0]
 
                 if do_true_cfg:
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
                     noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
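The `__call__` hunks replace the second transformer forward pass of the original implementation with a single batched pass: the latents are duplicated, the negative embeddings ride along in the batch dimension (negatives first, matching the `torch.cat` order above), and `chunk(2)` recovers both predictions. A minimal sketch of that pattern, with a hypothetical `model` function standing in for the Flux transformer:

```py
import torch

def model(x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the Flux transformer forward pass.
    return x + cond.mean(dim=(1, 2), keepdim=True)

latents = torch.randn(1, 16, 64)
neg_cond = torch.randn(1, 77, 64)  # negative prompt embeddings
pos_cond = torch.randn(1, 77, 64)  # positive prompt embeddings
true_cfg = 1.5

# One batched forward pass instead of two sequential ones: negative first.
latent_model_input = torch.cat([latents] * 2)
cond = torch.cat([neg_cond, pos_cond], dim=0)
noise_pred = model(latent_model_input, cond)

# Split the batch back apart and combine with the true-CFG rule.
neg_noise_pred, noise_pred = noise_pred.chunk(2)
noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
```

The design trade-off here is the usual one for batched CFG: roughly half as many transformer invocations per denoising step in exchange for doubled activation memory during the forward pass.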