diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py
index 96dced267baa..e58afe61574d 100644
--- a/src/diffusers/guider.py
+++ b/src/diffusers/guider.py
@@ -188,6 +188,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_classifier_free_guidance:
             return model_output
@@ -476,6 +477,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_perturbed_attention_guidance:
             return model_output
@@ -501,3 +503,231 @@ def apply_guidance(
             noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
 
         return noise_pred
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+class APGGuider:
+    """
+    This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
+    """
+
+    def normalized_guidance(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: torch.Tensor,
+        guidance_scale: float,
+        momentum_buffer: MomentumBuffer = None,
+        norm_threshold: float = 0.0,
+        eta: float = 1.0,
+    ):
+        """
+        Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
+        Models](https://arxiv.org/pdf/2410.02416)
+        """
+        diff = pred_cond - pred_uncond
+        if momentum_buffer is not None:
+            momentum_buffer.update(diff)
+            diff = momentum_buffer.running_average
+        if norm_threshold > 0:
+            ones = torch.ones_like(diff)
+            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+            scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+            diff = diff * scale_factor
+        v0, v1 = diff.double(), pred_cond.double()
+        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+        v0_orthogonal = v0 - v0_parallel
+        diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
+        normalized_update = diff_orthogonal + eta * diff_parallel
+        pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+        return pred_guided
+
+    @property
+    def adaptive_projected_guidance_momentum(self):
+        return self._adaptive_projected_guidance_momentum
+
+    @property
+    def adaptive_projected_guidance_rescale_factor(self):
+        return self._adaptive_projected_guidance_rescale_factor
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0 and not self._disable_guidance
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
+        disable_guidance = guider_kwargs.get("disable_guidance", False)
+        guidance_scale = guider_kwargs.get("guidance_scale", None)
+        if guidance_scale is None:
+            raise ValueError("guidance_scale is not provided in guider_kwargs")
+        adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
+        adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
+            "adaptive_projected_guidance_rescale_factor", 15.0
+        )
+        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
+        batch_size = guider_kwargs.get("batch_size", None)
+        if batch_size is None:
+            raise ValueError("batch_size is not provided in guider_kwargs")
+        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._batch_size = batch_size
+        self._disable_guidance = disable_guidance
+        if adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
+        else:
+            self.momentum_buffer = None
+        self.scheduler = pipeline.scheduler
+
+    def reset_guider(self, pipeline):
+        pass
+
+    def maybe_update_guider(self, pipeline, timestep):
+        pass
+
+    def maybe_update_input(self, pipeline, cond_input):
+        pass
+
+    def _maybe_split_prepared_input(self, cond):
+        """
+        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
+
+        This method handles inputs that may already have CFG applied (i.e. when `cond` is the output of
+        `prepare_input`). It determines whether to split the input based on its batch size relative to the expected
+        batch size.
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to process.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - The negative conditional input (uncond_input)
+                - The positive conditional input (cond_input)
+        """
+        if cond.shape[0] == self.batch_size * 2:
+            neg_cond = cond[0 : self.batch_size]
+            cond = cond[self.batch_size :]
+            return neg_cond, cond
+        elif cond.shape[0] == self.batch_size:
+            return cond, cond
+        else:
+            raise ValueError(f"Unsupported input shape: {cond.shape}")
+
+    def _is_prepared_input(self, cond):
+        """
+        Check if the input is already prepared for Classifier-Free Guidance (CFG).
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to check.
+
+        Returns:
+            bool: True if the input is already prepared, False otherwise.
+        """
+        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
+
+        return cond_tensor.shape[0] == self.batch_size * 2
+
+    def prepare_input(
+        self,
+        cond_input: Union[torch.Tensor, List[torch.Tensor]],
+        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Prepare the input for CFG.
+
+        Args:
+            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The conditional input. It can be a single tensor or a list of tensors. It must have the same length
+                as `negative_cond_input`.
+            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
+                length as `cond_input`.
+
+        Returns:
+            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
+        """
+
+        # check if cond_input already has CFG applied, and split it if that is the case
+        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
+            return cond_input
+
+        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
+            if isinstance(cond_input, list):
+                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
+            else:
+                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
+
+        if not self._is_prepared_input(cond_input) and negative_cond_input is None:
+            raise ValueError(
+                "`negative_cond_input` is required when `cond_input` does not already contain the negative conditional input"
+            )
+
+        if isinstance(cond_input, (list, tuple)):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+
+            if len(negative_cond_input) != len(cond_input):
+                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
+            prepared_input = []
+            for neg_cond, cond in zip(negative_cond_input, cond_input):
+                if neg_cond.shape[0] != cond.shape[0]:
+                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
+                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
+            return prepared_input
+
+        elif isinstance(cond_input, torch.Tensor):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+            else:
+                return torch.cat([negative_cond_input, cond_input], dim=0)
+
+        else:
+            raise ValueError(f"Unsupported input type: {type(cond_input)}")
+
+    def apply_guidance(
+        self,
+        model_output: torch.Tensor,
+        timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if not self.do_classifier_free_guidance:
+            return model_output
+
+        if latents is None:
+            raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
+
+        sigma = self.scheduler.sigmas[self.scheduler.step_index]
+        noise_pred = latents - sigma * model_output
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = self.normalized_guidance(
+            noise_pred_text,
+            noise_pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.adaptive_projected_guidance_rescale_factor,
+        )
+        noise_pred = (latents - noise_pred) / sigma
+
+        if self.guidance_rescale > 0.0:
+            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+        return noise_pred
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index 948e2b11fd1d..971f7336cf76 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -926,6 +926,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
                     noise_pred = pipeline.guider.apply_guidance(
                         noise_pred,
                         timestep=t,
+                        latents=latents,
                     )
                     # compute the previous noisy sample x_t -> x_t-1
                     latents_dtype = latents.dtype
@@ -1213,7 +1214,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
                         return_dict=False,
                     )[0]
                     # perform guidance
-                    noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t)
+                    noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
                    # compute the previous noisy sample x_t -> x_t-1
                     latents_dtype = latents.dtype
                     latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
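
For context, a minimal, self-contained sketch of the projection math the new APGGuider performs, not part of the patch. It assumes the src/diffusers/guider.py module path introduced above; the SimpleNamespace pipeline stub, the tensor shapes, and the guider kwargs below are illustrative only.

# Hedged sketch (not part of this diff): exercise APGGuider.normalized_guidance
# on dummy tensors. The pipeline stub and shapes are assumptions for the demo.
from types import SimpleNamespace

import torch

from diffusers.guider import APGGuider  # module path introduced by this diff

guider = APGGuider()
# set_guider only stores pipeline.scheduler; a stub suffices until apply_guidance.
guider.set_guider(
    SimpleNamespace(scheduler=None),
    {
        "guidance_scale": 7.5,
        "batch_size": 1,
        "adaptive_projected_guidance_momentum": -0.5,  # negative momentum, per the APG paper
        "adaptive_projected_guidance_rescale_factor": 15.0,
    },
)

# Dummy conditional / unconditional denoised (x0) predictions: (batch, channels, h, w).
pred_cond = torch.randn(1, 4, 64, 64)
pred_uncond = torch.randn(1, 4, 64, 64)

# Mirrors the call inside apply_guidance: the rescale factor is passed as norm_threshold.
guided = guider.normalized_guidance(
    pred_cond,
    pred_uncond,
    guidance_scale=guider.guidance_scale,
    momentum_buffer=guider.momentum_buffer,
    norm_threshold=guider.adaptive_projected_guidance_rescale_factor,
)
assert guided.shape == pred_cond.shape

Design note: APG is defined on denoised predictions, so apply_guidance first converts the epsilon-style model_output to x0 via x0 = latents - sigma * eps, projects in x0 space, then maps back with eps = (latents - x0) / sigma so that scheduler.step can be called unchanged. This is why the modular SDXL blocks in the second file now thread latents into apply_guidance.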