Skip to content

Commit 284f827

Browse files
authored
Modular custom config object serialization (#11868)
* update
* make style
1 parent b750c69 commit 284f827

File tree

7 files changed

+76
-15
lines changed

7 files changed

+76
-15
lines changed

src/diffusers/configuration_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,10 @@ def to_json_saveable(value):
601601
value = value.tolist()
602602
elif isinstance(value, Path):
603603
value = value.as_posix()
604+
elif hasattr(value, "to_dict") and callable(value.to_dict):
605+
value = value.to_dict()
606+
elif isinstance(value, list):
607+
value = [to_json_saveable(v) for v in value]
604608
return value
605609

606610
if "quantization_config" in config_dict:

src/diffusers/guiders/auto_guidance.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
16+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -66,7 +66,7 @@ def __init__(
6666
self,
6767
guidance_scale: float = 7.5,
6868
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
69-
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
69+
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
7070
dropout: Optional[float] = None,
7171
guidance_rescale: float = 0.0,
7272
use_original_formulation: bool = False,
@@ -104,13 +104,18 @@ def __init__(
104104
LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
105105
]
106106

107+
if isinstance(auto_guidance_config, dict):
108+
auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
109+
107110
if isinstance(auto_guidance_config, LayerSkipConfig):
108111
auto_guidance_config = [auto_guidance_config]
109112

110113
if not isinstance(auto_guidance_config, list):
111114
raise ValueError(
112115
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
113116
)
117+
elif isinstance(next(iter(auto_guidance_config), None), dict):
118+
auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
114119

115120
self.auto_guidance_config = auto_guidance_config
116121
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Optional, Union
15+
from typing import Any, Dict, List, Optional, Union
1616

1717
from ..configuration_utils import register_to_config
1818
from ..hooks import LayerSkipConfig
19+
from ..utils import get_logger
1920
from .skip_layer_guidance import SkipLayerGuidance
2021

2122

23+
logger = get_logger(__name__) # pylint: disable=invalid-name
24+
25+
2226
class PerturbedAttentionGuidance(SkipLayerGuidance):
2327
"""
2428
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
@@ -48,8 +52,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
4852
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
4953
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
5054
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
51-
If not provided, `skip_layer_config` must be provided.
52-
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
55+
If not provided, `perturbed_guidance_config` must be provided.
56+
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
5357
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
5458
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
5559
guidance_rescale (`float`, defaults to `0.0`):
@@ -79,36 +83,60 @@ def __init__(
7983
perturbed_guidance_start: float = 0.01,
8084
perturbed_guidance_stop: float = 0.2,
8185
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
82-
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
86+
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
8387
guidance_rescale: float = 0.0,
8488
use_original_formulation: bool = False,
8589
start: float = 0.0,
8690
stop: float = 1.0,
8791
):
88-
if skip_layer_config is None:
92+
if perturbed_guidance_config is None:
8993
if perturbed_guidance_layers is None:
9094
raise ValueError(
91-
"`perturbed_guidance_layers` must be provided if `skip_layer_config` is not specified."
95+
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
9296
)
93-
skip_layer_config = LayerSkipConfig(
97+
perturbed_guidance_config = LayerSkipConfig(
9498
indices=perturbed_guidance_layers,
99+
fqn="auto",
95100
skip_attention=False,
96101
skip_attention_scores=True,
97102
skip_ff=False,
98103
)
99104
else:
100105
if perturbed_guidance_layers is not None:
101106
raise ValueError(
102-
"`perturbed_guidance_layers` should not be provided if `skip_layer_config` is specified."
107+
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
108+
)
109+
110+
if isinstance(perturbed_guidance_config, dict):
111+
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
112+
113+
if isinstance(perturbed_guidance_config, LayerSkipConfig):
114+
perturbed_guidance_config = [perturbed_guidance_config]
115+
116+
if not isinstance(perturbed_guidance_config, list):
117+
raise ValueError(
118+
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
119+
)
120+
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
121+
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
122+
123+
for config in perturbed_guidance_config:
124+
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
125+
logger.warning(
126+
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
127+
"Please check your configuration. Modifying the config to match the expected values."
103128
)
129+
config.skip_attention = False
130+
config.skip_attention_scores = True
131+
config.skip_ff = False
104132

105133
super().__init__(
106134
guidance_scale=guidance_scale,
107135
skip_layer_guidance_scale=perturbed_guidance_scale,
108136
skip_layer_guidance_start=perturbed_guidance_start,
109137
skip_layer_guidance_stop=perturbed_guidance_stop,
110138
skip_layer_guidance_layers=perturbed_guidance_layers,
111-
skip_layer_config=skip_layer_config,
139+
skip_layer_config=perturbed_guidance_config,
112140
guidance_rescale=guidance_rescale,
113141
use_original_formulation=use_original_formulation,
114142
start=start,

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
16+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -95,7 +95,7 @@ def __init__(
9595
skip_layer_guidance_start: float = 0.01,
9696
skip_layer_guidance_stop: float = 0.2,
9797
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
98-
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
98+
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
9999
guidance_rescale: float = 0.0,
100100
use_original_formulation: bool = False,
101101
start: float = 0.0,
@@ -135,13 +135,18 @@ def __init__(
135135
)
136136
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
137137

138+
if isinstance(skip_layer_config, dict):
139+
skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
140+
138141
if isinstance(skip_layer_config, LayerSkipConfig):
139142
skip_layer_config = [skip_layer_config]
140143

141144
if not isinstance(skip_layer_config, list):
142145
raise ValueError(
143146
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
144147
)
148+
elif isinstance(next(iter(skip_layer_config), None), dict):
149+
skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
145150

146151
self.skip_layer_config = skip_layer_config
147152
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,18 @@ def __init__(
125125
)
126126
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
127127

128+
if isinstance(seg_guidance_config, dict):
129+
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
130+
128131
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
129132
seg_guidance_config = [seg_guidance_config]
130133

131134
if not isinstance(seg_guidance_config, list):
132135
raise ValueError(
133136
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
134137
)
138+
elif isinstance(next(iter(seg_guidance_config), None), dict):
139+
seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
135140

136141
self.seg_guidance_config = seg_guidance_config
137142
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

src/diffusers/hooks/layer_skip.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from dataclasses import dataclass
16+
from dataclasses import asdict, dataclass
1717
from typing import Callable, List, Optional
1818

1919
import torch
@@ -78,6 +78,13 @@ def __post_init__(self):
7878
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
7979
)
8080

81+
def to_dict(self):
    """Return a plain-dict snapshot of this dataclass config.

    Uses `dataclasses.asdict`, so nested dataclasses/lists are converted
    recursively. The result feeds the JSON serialization path in
    `configuration_utils.to_json_saveable`, which looks for a callable
    `to_dict` attribute.
    """
    return asdict(self)
83+
84+
@classmethod
def from_dict(cls, data: dict) -> "LayerSkipConfig":
    """Recreate a config from a dict produced by `to_dict` (or a user mapping).

    A classmethod rather than a staticmethod so that subclasses calling
    `SubConfig.from_dict(...)` round-trip to their own type instead of
    always materializing a `LayerSkipConfig`. Existing call sites
    (`LayerSkipConfig.from_dict(d)`) are unaffected.

    Raises:
        TypeError: if `data` contains keys that are not fields of the class
            (propagated from the dataclass constructor).
    """
    return cls(**data)
87+
8188

8289
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
8390
def __torch_function__(self, func, types, args=(), kwargs=None):

src/diffusers/hooks/smoothed_energy_guidance_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from dataclasses import dataclass
16+
from dataclasses import asdict, dataclass
1717
from typing import List, Optional
1818

1919
import torch
@@ -51,6 +51,13 @@ class SmoothedEnergyGuidanceConfig:
5151
fqn: str = "auto"
5252
_query_proj_identifiers: List[str] = None
5353

54+
def to_dict(self):
    """Return a plain-dict snapshot of this dataclass config.

    Uses `dataclasses.asdict` (recursive), matching the serialization
    hook expected by `configuration_utils.to_json_saveable`.
    """
    return asdict(self)
56+
57+
@classmethod
def from_dict(cls, data: dict) -> "SmoothedEnergyGuidanceConfig":
    """Recreate a config from a dict produced by `to_dict` (or a user mapping).

    A classmethod rather than a staticmethod so subclasses round-trip to
    their own type; `SmoothedEnergyGuidanceConfig.from_dict(d)` call sites
    are unaffected. Mirrors `LayerSkipConfig.from_dict` for consistency.

    Raises:
        TypeError: if `data` contains keys that are not fields of the class
            (propagated from the dataclass constructor).
    """
    return cls(**data)
60+
5461

5562
class SmoothedEnergyGuidanceHook(ModelHook):
5663
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:

0 commit comments

Comments
 (0)