diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index b2772d552514..970151c682fb 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -16,11 +16,12 @@
 import inspect
 import re
 import urllib.parse as ul
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
 from ...models import AutoencoderKL, PixArtTransformer2DModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -171,7 +172,20 @@ def retrieve_timesteps(
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
     return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
 
 
 class PixArtSigmaPipeline(DiffusionPipeline):
@@ -196,7 +210,8 @@ class PixArtSigmaPipeline(DiffusionPipeline):
 
     _optional_components = ["tokenizer", "text_encoder"]
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
@@ -208,25 +223,38 @@ def __init__(
         super().__init__()
 
         self.register_modules(
-            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            scheduler=scheduler,
         )
 
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # fall back to 512, the value in the default checkpoint's tokenizer_config.json
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 512
+        )
+        # fall back to 128, the value in the default transformer's config.json
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
 
-    # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        do_classifier_free_guidance: bool = True,
         negative_prompt: str = "",
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
         clean_caption: bool = False,
+        do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 300,
         **kwargs,
     ):
@@ -256,30 +284,24 @@ def encode_prompt(
                 If `True`, the function will preprocess and clean the provided caption before encoding.
             max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
         """
         if "mask_feature" in kwargs:
             deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
             deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
 
-        if device is None:
-            device = self._execution_device
+        device = device or self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # See Section 3.1. of the paper.
-        max_length = max_sequence_length
-
         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
-                max_length=max_length,
+                max_length=max_sequence_length,
                 truncation=True,
                 add_special_tokens=True,
                 return_tensors="pt",
@@ -326,7 +348,7 @@ def encode_prompt(
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
-                max_length=max_length,
+                max_length=max_sequence_length,
                 truncation=True,
                 return_attention_mask=True,
                 add_special_tokens=True,
@@ -382,9 +404,10 @@ def check_inputs(
         height,
         width,
         negative_prompt,
-        callback_steps,
+        callback_steps=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
     ):
@@ -398,6 +421,13 @@ def check_inputs(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
+                f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -605,21 +645,44 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf. `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         negative_prompt: str = "",
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 20,
         timesteps: List[int] = None,
         sigmas: List[float] = None,
+        eta: float = 0.0,
         guidance_scale: float = 4.5,
         num_images_per_prompt: Optional[int] = 1,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
@@ -628,9 +691,12 @@ def __call__(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
-        callback_steps: int = 1,
         clean_caption: bool = True,
+        guidance_rescale: float = 0.0,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,  # deprecated, use `callback_on_step_end`
+        callback_steps: Optional[int] = None,  # deprecated, use `callback_on_step_end`
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         use_resolution_binning: bool = True,
         max_sequence_length: int = 300,
         **kwargs,
@@ -663,6 +729,9 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). A value of 0.0 disables rescaling. See Section 3.4.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -695,10 +764,20 @@ def __call__(
                 Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+                Deprecated; use `callback_on_step_end` instead.
-            callback_steps (`int`, *optional*, defaults to 1):
+            callback_steps (`int`, *optional*):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
+                called at every step. Deprecated; use `callback_on_step_end` instead.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
             clean_caption (`bool`, *optional*, defaults to `True`):
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy`
                 to be installed. If the dependencies are not installed, the embeddings will be created from the raw
@@ -716,9 +795,28 @@ def __call__(
             If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
             returned where the first element is a list with the generated images
         """
+
+        # adapted from diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        # restore the legacy default so `check_inputs` and the legacy `callback` keep working
+        callback_steps = callback_steps or 1
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
         if use_resolution_binning:
             if self.transformer.config.sample_size == 256:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
@@ -737,12 +835,13 @@ def __call__(
             prompt,
             height,
             width,
-            negative_prompt,
-            callback_steps,
-            prompt_embeds,
-            negative_prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            callback_steps=callback_steps,  # deprecated, kept for backward compatibility
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
         # 2. Default height and width to transformer
@@ -752,14 +851,13 @@ def __call__(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
 
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._interrupt = False
+
         device = self._execution_device
 
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         (
             prompt_embeds,
@@ -767,26 +865,27 @@ def __call__(
             negative_prompt_embeds,
             negative_prompt_attention_mask,
         ) = self.encode_prompt(
-            prompt,
-            do_classifier_free_guidance,
+            prompt=prompt,
             negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
+            device=device,
             clean_caption=clean_caption,
+            num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
         )
-        if do_classifier_free_guidance:
+
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
         )
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
@@ -808,11 +907,12 @@ def __call__(
             added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
 
         # 7. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 current_timestep = t
@@ -841,9 +941,13 @@ def __call__(
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
 
                 # learned sigma
                 if self.transformer.config.out_channels // 2 == latent_channels:
@@ -852,9 +956,25 @@ def __call__(
                     noise_pred = noise_pred
 
                 # compute previous image: x_t -> x_t-1
+                latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
-                # call the callback, if provided
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the legacy callback, if provided (deprecated in favor of `callback_on_step_end`)
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
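For reviewers, a minimal sketch of how the new `callback_on_step_end` hook can be exercised. The checkpoint id and prompt are only illustrative; the callback here just logs latent statistics, and it must return its `callback_kwargs` dict because the denoising loop pops `latents`, `prompt_embeds`, and `negative_prompt_embeds` back out of it.

```python
import torch
from diffusers import PixArtSigmaPipeline

# illustrative checkpoint; any PixArt-Sigma checkpoint should work
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")


def log_latents(pipeline, step, timestep, callback_kwargs):
    # inspect (or modify) the tensors requested via `callback_on_step_end_tensor_inputs`
    latents = callback_kwargs["latents"]
    print(f"step {step:03d}  t={int(timestep):4d}  latents std={latents.std().item():.4f}")
    # a callback can also abort the remaining steps early via the new interrupt flag:
    # pipeline._interrupt = True
    return callback_kwargs  # required: the loop reads the tensors back out of this dict


image = pipe(
    "an astronaut riding a horse on the moon",
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```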
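And a standalone sanity check (plain PyTorch, no pipeline needed) of what the copied `rescale_noise_cfg` helper does: classifier-free guidance inflates the per-sample standard deviation of the prediction, and rescaling with `guidance_rescale=1.0` pulls it back to that of the text-conditioned branch, per Section 3.4 of the paper. In the pipeline this is exposed through the new `guidance_rescale` argument, e.g. `pipe(prompt, guidance_rescale=0.7)`.

```python
import torch

torch.manual_seed(0)
noise_pred_text = torch.randn(2, 4, 64, 64)
noise_pred_uncond = torch.randn(2, 4, 64, 64)
guidance_scale = 7.0

# classifier-free guidance blows up the per-sample std of the prediction ...
noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# ... and rescaling with guidance_rescale=1.0 restores the text branch's std
std_text = noise_pred_text.std(dim=[1, 2, 3], keepdim=True)
std_cfg = noise_cfg.std(dim=[1, 2, 3], keepdim=True)
rescaled = noise_cfg * (std_text / std_cfg)

print(std_cfg.flatten())            # much larger than std_text (~9x for these inputs)
print(rescaled.std(dim=[1, 2, 3]))  # matches std_text per sample
print(std_text.flatten())
```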