|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from typing import List, Optional, Union |
| 15 | +from typing import Any, Dict, List, Optional, Union |
16 | 16 |
|
17 | 17 | from ..configuration_utils import register_to_config |
18 | 18 | from ..hooks import LayerSkipConfig |
| 19 | +from ..utils import get_logger |
19 | 20 | from .skip_layer_guidance import SkipLayerGuidance |
20 | 21 |
|
21 | 22 |
|
| 23 | +logger = get_logger(__name__) # pylint: disable=invalid-name |
| 24 | + |
| 25 | + |
22 | 26 | class PerturbedAttentionGuidance(SkipLayerGuidance): |
23 | 27 | """ |
24 | 28 | Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377 |
@@ -48,8 +52,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance): |
48 | 52 | The fraction of the total number of denoising steps after which perturbed attention guidance stops. |
49 | 53 | perturbed_guidance_layers (`int` or `List[int]`, *optional*): |
50 | 54 | The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers. |
51 | | - If not provided, `skip_layer_config` must be provided. |
52 | | - skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): |
| 55 | + If not provided, `perturbed_guidance_config` must be provided. |
| 56 | + perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): |
53 | 57 | The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of |
54 | 58 | `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided. |
55 | 59 | guidance_rescale (`float`, defaults to `0.0`): |
@@ -79,36 +83,60 @@ def __init__( |
79 | 83 | perturbed_guidance_start: float = 0.01, |
80 | 84 | perturbed_guidance_stop: float = 0.2, |
81 | 85 | perturbed_guidance_layers: Optional[Union[int, List[int]]] = None, |
82 | | - skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, |
| 86 | + perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, |
83 | 87 | guidance_rescale: float = 0.0, |
84 | 88 | use_original_formulation: bool = False, |
85 | 89 | start: float = 0.0, |
86 | 90 | stop: float = 1.0, |
87 | 91 | ): |
88 | | - if skip_layer_config is None: |
| 92 | + if perturbed_guidance_config is None: |
89 | 93 | if perturbed_guidance_layers is None: |
90 | 94 | raise ValueError( |
91 | | - "`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified." |
| 95 | + "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified." |
92 | 96 | ) |
93 | | - skip_layer_config = LayerSkipConfig( |
| 97 | + perturbed_guidance_config = LayerSkipConfig( |
94 | 98 | indices=perturbed_guidance_layers, |
| 99 | + fqn="auto", |
95 | 100 | skip_attention=False, |
96 | 101 | skip_attention_scores=True, |
97 | 102 | skip_ff=False, |
98 | 103 | ) |
99 | 104 | else: |
100 | 105 | if perturbed_guidance_layers is not None: |
101 | 106 | raise ValueError( |
102 | | - "`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified." |
| 107 | + "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified." |
| 108 | + ) |
| 109 | + |
| 110 | + if isinstance(perturbed_guidance_config, dict): |
| 111 | + perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config) |
| 112 | + |
| 113 | + if isinstance(perturbed_guidance_config, LayerSkipConfig): |
| 114 | + perturbed_guidance_config = [perturbed_guidance_config] |
| 115 | + |
| 116 | + if not isinstance(perturbed_guidance_config, list): |
| 117 | + raise ValueError( |
| 118 | + "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`." |
| 119 | + ) |
| 120 | + elif isinstance(next(iter(perturbed_guidance_config), None), dict): |
| 121 | + perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config] |
| 122 | + |
| 123 | + for config in perturbed_guidance_config: |
| 124 | + if config.skip_attention or not config.skip_attention_scores or config.skip_ff: |
| 125 | + logger.warning( |
| 126 | + "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. " |
| 127 | + "Please check your configuration. Modifying the config to match the expected values." |
103 | 128 | ) |
| 129 | + config.skip_attention = False |
| 130 | + config.skip_attention_scores = True |
| 131 | + config.skip_ff = False |
104 | 132 |
|
105 | 133 | super().__init__( |
106 | 134 | guidance_scale=guidance_scale, |
107 | 135 | skip_layer_guidance_scale=perturbed_guidance_scale, |
108 | 136 | skip_layer_guidance_start=perturbed_guidance_start, |
109 | 137 | skip_layer_guidance_stop=perturbed_guidance_stop, |
110 | 138 | skip_layer_guidance_layers=perturbed_guidance_layers, |
111 | | - skip_layer_config=skip_layer_config, |
| 139 | + skip_layer_config=perturbed_guidance_config, |
112 | 140 | guidance_rescale=guidance_rescale, |
113 | 141 | use_original_formulation=use_original_formulation, |
114 | 142 | start=start, |
|
0 commit comments