diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index f9b652bbc021..770c949ffb3d 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -601,6 +601,10 @@ def to_json_saveable(value):
                 value = value.tolist()
             elif isinstance(value, Path):
                 value = value.as_posix()
+            elif hasattr(value, "to_dict") and callable(value.to_dict):
+                value = value.to_dict()
+            elif isinstance(value, list):
+                value = [to_json_saveable(v) for v in value]
             return value
 
         if "quantization_config" in config_dict:
diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py
index 159354559966..e1642211d393 100644
--- a/src/diffusers/guiders/auto_guidance.py
+++ b/src/diffusers/guiders/auto_guidance.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -66,7 +66,7 @@ def __init__(
         self,
         guidance_scale: float = 7.5,
         auto_guidance_layers: Optional[Union[int, List[int]]] = None,
-        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
+        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
         dropout: Optional[float] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
@@ -104,6 +104,9 @@ def __init__(
                 LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
             ]
 
+        if isinstance(auto_guidance_config, dict):
+            auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
+
         if isinstance(auto_guidance_config, LayerSkipConfig):
             auto_guidance_config = [auto_guidance_config]
 
@@ -111,6 +114,8 @@ def __init__(
             raise ValueError(
                 f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
             )
+        elif isinstance(next(iter(auto_guidance_config), None), dict):
+            auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
 
         self.auto_guidance_config = auto_guidance_config
         self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py
index dbba904d0bde..3045f2feaae2 100644
--- a/src/diffusers/guiders/perturbed_attention_guidance.py
+++ b/src/diffusers/guiders/perturbed_attention_guidance.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from ..configuration_utils import register_to_config
 from ..hooks import LayerSkipConfig
+from ..utils import get_logger
 from .skip_layer_guidance import SkipLayerGuidance
 
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class PerturbedAttentionGuidance(SkipLayerGuidance):
     """
     Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
@@ -48,8 +52,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
             The fraction of the total number of denoising steps after which perturbed attention guidance stops.
         perturbed_guidance_layers (`int` or `List[int]`, *optional*):
             The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
-            If not provided, `skip_layer_config` must be provided.
-        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+            If not provided, `perturbed_guidance_config` must be provided.
+        perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
             The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
             `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
         guidance_rescale (`float`, defaults to `0.0`):
@@ -79,19 +83,20 @@ def __init__(
         perturbed_guidance_start: float = 0.01,
         perturbed_guidance_stop: float = 0.2,
         perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
-        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
+        perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
         start: float = 0.0,
         stop: float = 1.0,
     ):
-        if skip_layer_config is None:
+        if perturbed_guidance_config is None:
             if perturbed_guidance_layers is None:
                 raise ValueError(
-                    "`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified."
+                    "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
                 )
-            skip_layer_config = LayerSkipConfig(
+            perturbed_guidance_config = LayerSkipConfig(
                 indices=perturbed_guidance_layers,
+                fqn="auto",
                 skip_attention=False,
                 skip_attention_scores=True,
                 skip_ff=False,
@@ -99,8 +104,31 @@ def __init__(
         else:
             if perturbed_guidance_layers is not None:
                 raise ValueError(
-                    "`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified."
+                    "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
+                )
+
+        if isinstance(perturbed_guidance_config, dict):
+            perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
+
+        if isinstance(perturbed_guidance_config, LayerSkipConfig):
+            perturbed_guidance_config = [perturbed_guidance_config]
+
+        if not isinstance(perturbed_guidance_config, list):
+            raise ValueError(
+                "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
+            )
+        elif isinstance(next(iter(perturbed_guidance_config), None), dict):
+            perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
+
+        for config in perturbed_guidance_config:
+            if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
+                logger.warning(
+                    "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
+                    "Please check your configuration. Modifying the config to match the expected values."
                 )
+                config.skip_attention = False
+                config.skip_attention_scores = True
+                config.skip_ff = False
 
         super().__init__(
             guidance_scale=guidance_scale,
@@ -108,7 +136,7 @@ def __init__(
             skip_layer_guidance_start=perturbed_guidance_start,
             skip_layer_guidance_stop=perturbed_guidance_stop,
             skip_layer_guidance_layers=perturbed_guidance_layers,
-            skip_layer_config=skip_layer_config,
+            skip_layer_config=perturbed_guidance_config,
             guidance_rescale=guidance_rescale,
             use_original_formulation=use_original_formulation,
             start=start,
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
index e67b20df19fa..68a657960a45 100644
--- a/src/diffusers/guiders/skip_layer_guidance.py
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -95,7 +95,7 @@ def __init__(
         skip_layer_guidance_start: float = 0.01,
         skip_layer_guidance_stop: float = 0.2,
         skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
-        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
+        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
         start: float = 0.0,
@@ -135,6 +135,9 @@ def __init__(
             )
             skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
 
+        if isinstance(skip_layer_config, dict):
+            skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
+
         if isinstance(skip_layer_config, LayerSkipConfig):
             skip_layer_config = [skip_layer_config]
 
@@ -142,6 +145,8 @@ def __init__(
             raise ValueError(
                 f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
             )
+        elif isinstance(next(iter(skip_layer_config), None), dict):
+            skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
 
         self.skip_layer_config = skip_layer_config
         self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py
index 66c46064d46d..d8e8a3cf2fa8 100644
--- a/src/diffusers/guiders/smoothed_energy_guidance.py
+++ b/src/diffusers/guiders/smoothed_energy_guidance.py
@@ -125,6 +125,9 @@ def __init__(
             )
             seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
 
+        if isinstance(seg_guidance_config, dict):
+            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
+
         if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
             seg_guidance_config = [seg_guidance_config]
 
@@ -132,6 +135,8 @@ def __init__(
             raise ValueError(
                 f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
             )
+        elif isinstance(next(iter(seg_guidance_config), None), dict):
+            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
 
         self.seg_guidance_config = seg_guidance_config
         self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py
index 32c9f205d683..487a1876d605 100644
--- a/src/diffusers/hooks/layer_skip.py
+++ b/src/diffusers/hooks/layer_skip.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Callable, List, Optional
 
 import torch
@@ -78,6 +78,13 @@ def __post_init__(self):
                 "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
             )
 
+    def to_dict(self):
+        return asdict(self)
+
+    @staticmethod
+    def from_dict(data: dict) -> "LayerSkipConfig":
+        return LayerSkipConfig(**data)
+
 
 class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
     def __torch_function__(self, func, types, args=(), kwargs=None):
diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py
index 65cce3c53907..622f60764762 100644
--- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py
+++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import List, Optional
 
 import torch
@@ -51,6 +51,13 @@ class SmoothedEnergyGuidanceConfig:
     fqn: str = "auto"
     _query_proj_identifiers: List[str] = None
 
+    def to_dict(self):
+        return asdict(self)
+
+    @staticmethod
+    def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
+        return SmoothedEnergyGuidanceConfig(**data)
+
 
 class SmoothedEnergyGuidanceHook(ModelHook):
     def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
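Taken together, these changes let the guider constructors accept a plain dict (or a list of dicts) wherever a `LayerSkipConfig` / `SmoothedEnergyGuidanceConfig` was previously required, and give those configs a `to_dict` / `from_dict` round trip so `to_json_saveable` can serialize them. A minimal usage sketch follows; the module paths are inferred from the file layout in this diff, the scale and layer indices are placeholder values, and the surrounding pipeline setup is omitted, so treat it as an illustration rather than the canonical API:

from diffusers.guiders.perturbed_attention_guidance import PerturbedAttentionGuidance

# A dict is now converted internally via LayerSkipConfig.from_dict(...); the PAG
# constructor also corrects skip_attention / skip_attention_scores / skip_ff
# values that contradict the PAG formulation, emitting a warning.
guider = PerturbedAttentionGuidance(
    guidance_scale=5.0,
    perturbed_guidance_config={
        "indices": [2, 9],
        "fqn": "auto",
        "skip_attention": False,
        "skip_attention_scores": True,
        "skip_ff": False,
    },
)

# The stored configs are LayerSkipConfig instances; to_dict() returns plain
# dicts, which is what to_json_saveable relies on when writing the JSON config.
config_dicts = [config.to_dict() for config in guider.skip_layer_config]
assert all(isinstance(d, dict) for d in config_dicts)

The dict path mirrors what `to_json_saveable` now emits, so a guider saved through `to_json_string` should be re-instantiable from its own JSON config without manually rebuilding the config objects.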