|  | 
| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
| 14 | 14 | 
 | 
| 15 |  | -from typing import List, Optional, Union | 
|  | 15 | +from typing import Any, Dict, List, Optional, Union | 
| 16 | 16 | 
 | 
| 17 | 17 | from ..configuration_utils import register_to_config | 
| 18 | 18 | from ..hooks import LayerSkipConfig | 
|  | 19 | +from ..utils import get_logger | 
| 19 | 20 | from .skip_layer_guidance import SkipLayerGuidance | 
| 20 | 21 | 
 | 
| 21 | 22 | 
 | 
|  | 23 | +logger = get_logger(__name__)  # pylint: disable=invalid-name | 
|  | 24 | + | 
|  | 25 | + | 
| 22 | 26 | class PerturbedAttentionGuidance(SkipLayerGuidance): | 
| 23 | 27 |     """ | 
| 24 | 28 |     Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377 | 
| @@ -48,8 +52,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance): | 
| 48 | 52 |             The fraction of the total number of denoising steps after which perturbed attention guidance stops. | 
| 49 | 53 |         perturbed_guidance_layers (`int` or `List[int]`, *optional*): | 
| 50 | 54 |             The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers. | 
| 51 |  | -            If not provided, `skip_layer_config` must be provided. | 
| 52 |  | -        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): | 
|  | 55 | +            If not provided, `perturbed_guidance_config` must be provided. | 
|  | 56 | +        perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): | 
| 53 | 57 |             The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of | 
| 54 | 58 |             `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided. | 
| 55 | 59 |         guidance_rescale (`float`, defaults to `0.0`): | 
| @@ -79,36 +83,62 @@ def __init__( | 
| 79 | 83 |         perturbed_guidance_start: float = 0.01, | 
| 80 | 84 |         perturbed_guidance_stop: float = 0.2, | 
| 81 | 85 |         perturbed_guidance_layers: Optional[Union[int, List[int]]] = None, | 
| 82 |  | -        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, | 
|  | 86 | +        perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, | 
| 83 | 87 |         guidance_rescale: float = 0.0, | 
| 84 | 88 |         use_original_formulation: bool = False, | 
| 85 | 89 |         start: float = 0.0, | 
| 86 | 90 |         stop: float = 1.0, | 
| 87 | 91 |     ): | 
| 88 |  | -        if skip_layer_config is None: | 
|  | 92 | +        if perturbed_guidance_config is None: | 
| 89 | 93 |             if perturbed_guidance_layers is None: | 
| 90 | 94 |                 raise ValueError( | 
| 91 |  | -                    "`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified." | 
|  | 95 | +                    "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified." | 
| 92 | 96 |                 ) | 
| 93 |  | -            skip_layer_config = LayerSkipConfig( | 
|  | 97 | +            perturbed_guidance_config = LayerSkipConfig( | 
| 94 | 98 |                 indices=perturbed_guidance_layers, | 
|  | 99 | +                fqn="auto", | 
| 95 | 100 |                 skip_attention=False, | 
| 96 | 101 |                 skip_attention_scores=True, | 
| 97 | 102 |                 skip_ff=False, | 
| 98 | 103 |             ) | 
| 99 | 104 |         else: | 
| 100 | 105 |             if perturbed_guidance_layers is not None: | 
| 101 | 106 |                 raise ValueError( | 
| 102 |  | -                    "`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified." | 
|  | 107 | +                    "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified." | 
|  | 108 | +                ) | 
|  | 109 | + | 
|  | 110 | +        if isinstance(perturbed_guidance_config, dict): | 
|  | 111 | +            perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config) | 
|  | 112 | + | 
|  | 113 | +        if isinstance(perturbed_guidance_config, LayerSkipConfig): | 
|  | 114 | +            perturbed_guidance_config = [perturbed_guidance_config] | 
|  | 115 | + | 
|  | 116 | +        if not isinstance(perturbed_guidance_config, list): | 
|  | 117 | +            raise ValueError( | 
|  | 118 | +                "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`." | 
|  | 119 | +            ) | 
|  | 120 | +        elif isinstance(next(iter(perturbed_guidance_config), None), dict): | 
|  | 121 | +                perturbed_guidance_config = [ | 
|  | 122 | +                    LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config | 
|  | 123 | +                ] | 
|  | 124 | + | 
|  | 125 | +        for config in perturbed_guidance_config: | 
|  | 126 | +            if config.skip_attention or not config.skip_attention_scores or config.skip_ff: | 
|  | 127 | +                logger.warning( | 
|  | 128 | +                    "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. " | 
|  | 129 | +                    "Please check your configuration. Modifying the config to match the expected values." | 
| 103 | 130 |                 ) | 
|  | 131 | +            config.skip_attention = False | 
|  | 132 | +            config.skip_attention_scores = True | 
|  | 133 | +            config.skip_ff = False | 
| 104 | 134 | 
 | 
| 105 | 135 |         super().__init__( | 
| 106 | 136 |             guidance_scale=guidance_scale, | 
| 107 | 137 |             skip_layer_guidance_scale=perturbed_guidance_scale, | 
| 108 | 138 |             skip_layer_guidance_start=perturbed_guidance_start, | 
| 109 | 139 |             skip_layer_guidance_stop=perturbed_guidance_stop, | 
| 110 | 140 |             skip_layer_guidance_layers=perturbed_guidance_layers, | 
| 111 |  | -            skip_layer_config=skip_layer_config, | 
|  | 141 | +            skip_layer_config=perturbed_guidance_config, | 
| 112 | 142 |             guidance_rescale=guidance_rescale, | 
| 113 | 143 |             use_original_formulation=use_original_formulation, | 
| 114 | 144 |             start=start, | 
|  | 
0 commit comments