@@ -188,6 +188,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_classifier_free_guidance:
             return model_output
@@ -476,6 +477,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_perturbed_attention_guidance:
             return model_output
@@ -501,3 +503,231 @@ def apply_guidance(
         noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

         return noise_pred
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
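+        # Exponential-style accumulator: running_average_t = update_value + momentum * running_average_{t-1}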
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+class APGGuider:
+    """
+    This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
+    """
+
+    def normalized_guidance(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: torch.Tensor,
+        guidance_scale: float,
+        momentum_buffer: MomentumBuffer = None,
+        norm_threshold: float = 0.0,
+        eta: float = 1.0,
+    ):
+        """
+        Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
+        in Diffusion Models](https://arxiv.org/pdf/2410.02416)
+        """
+        diff = pred_cond - pred_uncond
+        if momentum_buffer is not None:
+            momentum_buffer.update(diff)
+            diff = momentum_buffer.running_average
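+        # Optionally cap the L2 norm of the guidance difference at norm_threshold.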
+        if norm_threshold > 0:
+            ones = torch.ones_like(diff)
+            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+            scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+            diff = diff * scale_factor
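+        # Split the guidance difference into components parallel and orthogonal to the normalized
+        # conditional prediction; eta down-weights the parallel part, which the APG paper links to
+        # oversaturation at high guidance scales.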
+        v0, v1 = diff.double(), pred_cond.double()
+        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+        v0_orthogonal = v0 - v0_parallel
+        diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
+        normalized_update = diff_orthogonal + eta * diff_parallel
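+        # Same update rule as CFG (pred_cond + (scale - 1) * diff), but using the projected,
+        # optionally momentum-smoothed and norm-clipped difference instead of the raw one.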
+        pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+        return pred_guided
+
+    @property
+    def adaptive_projected_guidance_momentum(self):
+        return self._adaptive_projected_guidance_momentum
+
+    @property
+    def adaptive_projected_guidance_rescale_factor(self):
+        return self._adaptive_projected_guidance_rescale_factor
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0 and not self._disable_guidance
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
+        disable_guidance = guider_kwargs.get("disable_guidance", False)
+        guidance_scale = guider_kwargs.get("guidance_scale", None)
+        if guidance_scale is None:
+            raise ValueError("guidance_scale is not provided in guider_kwargs")
+        adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
+        adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
+            "adaptive_projected_guidance_rescale_factor", 15.0
+        )
+        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
+        batch_size = guider_kwargs.get("batch_size", None)
+        if batch_size is None:
+            raise ValueError("batch_size is not provided in guider_kwargs")
+        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._batch_size = batch_size
+        self._disable_guidance = disable_guidance
+        if adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
+        else:
+            self.momentum_buffer = None
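+        # Keep a reference to the scheduler: apply_guidance reads sigmas / step_index from it
+        # to convert between the epsilon prediction and the denoised (x0) prediction.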
+        self.scheduler = pipeline.scheduler
+
+    def reset_guider(self, pipeline):
+        pass
+
+    def maybe_update_guider(self, pipeline, timestep):
+        pass
+
+    def maybe_update_input(self, pipeline, cond_input):
+        pass
+
+    def _maybe_split_prepared_input(self, cond):
+        """
+        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
+
+        This method handles inputs that may already have CFG applied (i.e., when `cond` is the output of
+        `prepare_input`). It determines whether to split the input based on its batch size relative to the
+        expected batch size.
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to process.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - The negative conditional input (uncond_input)
+                - The positive conditional input (cond_input)
+        """
+        if cond.shape[0] == self.batch_size * 2:
+            neg_cond = cond[0 : self.batch_size]
+            cond = cond[self.batch_size :]
+            return neg_cond, cond
+        elif cond.shape[0] == self.batch_size:
+            return cond, cond
+        else:
+            raise ValueError(f"Unsupported input shape: {cond.shape}")
+
+    def _is_prepared_input(self, cond):
+        """
+        Check if the input is already prepared for Classifier-Free Guidance (CFG).
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to check.
+
+        Returns:
+            bool: True if the input is already prepared, False otherwise.
+        """
+        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
+
+        return cond_tensor.shape[0] == self.batch_size * 2
+
+    def prepare_input(
+        self,
+        cond_input: Union[torch.Tensor, List[torch.Tensor]],
+        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Prepare the input for CFG.
+
+        Args:
+            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The conditional input. It can be a single tensor or a list of tensors. It must have the same
+                length as `negative_cond_input`.
+            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The negative conditional input. It can be a single tensor or a list of tensors. It must have
+                the same length as `cond_input`.
+
+        Returns:
+            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
+        """
+
+        # Check whether cond_input already has CFG applied, and split it if that is the case.
+        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
+            return cond_input
+
+        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
+            if isinstance(cond_input, list):
+                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
+            else:
+                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
+
+        if not self._is_prepared_input(cond_input) and negative_cond_input is None:
+            raise ValueError(
+                "`negative_cond_input` is required when `cond_input` does not already contain the negative conditional input"
+            )
+
+        if isinstance(cond_input, (list, tuple)):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+
+            if len(negative_cond_input) != len(cond_input):
+                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
+            prepared_input = []
+            for neg_cond, cond in zip(negative_cond_input, cond_input):
+                if neg_cond.shape[0] != cond.shape[0]:
+                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
+                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
+            return prepared_input
+
+        elif isinstance(cond_input, torch.Tensor):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+            else:
+                return torch.cat([negative_cond_input, cond_input], dim=0)
+
+        else:
+            raise ValueError(f"Unsupported input type: {type(cond_input)}")
+
+    def apply_guidance(
+        self,
+        model_output: torch.Tensor,
+        timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if not self.do_classifier_free_guidance:
+            return model_output
+
+        if latents is None:
+            raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
+
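+        # Convert the epsilon prediction into a denoised (x0) prediction, apply APG there,
+        # then convert the result back to an epsilon prediction below.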
+        sigma = self.scheduler.sigmas[self.scheduler.step_index]
+        noise_pred = latents - sigma * model_output
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = self.normalized_guidance(
+            noise_pred_text,
+            noise_pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.adaptive_projected_guidance_rescale_factor,
+        )
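+        # Convert the guided x0 prediction back into an epsilon-style model output.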
+        noise_pred = (latents - noise_pred) / sigma
+
+        if self.guidance_rescale > 0.0:
+            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+        return noise_pred
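
For orientation, here is a minimal, single-prompt (batch size 1) sketch of how the new `APGGuider` could be driven from a denoising loop. The names `pipe`, `unet`, `cond_embeds`, `negative_cond_embeds`, and `latents`, as well as the momentum value, are illustrative assumptions and not part of this diff; the guider itself only requires `guidance_scale` and `batch_size` in `guider_kwargs` plus a scheduler exposing `sigmas` and `step_index`.

    guider = APGGuider()
    guider.set_guider(
        pipe,
        {
            "guidance_scale": 7.5,
            "batch_size": 1,
            "adaptive_projected_guidance_momentum": -0.5,  # illustrative; None disables the momentum buffer
        },
    )
    prompt_embeds = guider.prepare_input(cond_embeds, negative_cond_embeds)  # -> cat([neg, cond])

    for t in pipe.scheduler.timesteps:
        model_in = torch.cat([latents] * 2) if guider.do_classifier_free_guidance else latents
        noise_pred = pipe.unet(model_in, t, encoder_hidden_states=prompt_embeds).sample
        noise_pred = guider.apply_guidance(noise_pred, timestep=t, latents=latents)
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample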