Skip to content

Commit b750c69

Browse files
authored
Modular Guider ConfigMixin (#11862)
* update * update * register to config pag
1 parent 13c51bb commit b750c69

11 files changed

+23
-2
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from .guider_utils import BaseGuidance, rescale_noise_cfg
2122

2223

@@ -53,6 +54,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
5354

5455
_input_predictions = ["pred_cond", "pred_uncond"]
5556

57+
@register_to_config
5658
def __init__(
5759
self,
5860
guidance_scale: float = 7.5,

src/diffusers/guiders/auto_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from ..hooks import HookRegistry, LayerSkipConfig
2122
from ..hooks.layer_skip import _apply_layer_skip_hook
2223
from .guider_utils import BaseGuidance, rescale_noise_cfg
@@ -60,6 +61,7 @@ class AutoGuidance(BaseGuidance):
6061

6162
_input_predictions = ["pred_cond", "pred_uncond"]
6263

64+
@register_to_config
6365
def __init__(
6466
self,
6567
guidance_scale: float = 7.5,

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from .guider_utils import BaseGuidance, rescale_noise_cfg
2122

2223

@@ -67,6 +68,7 @@ class ClassifierFreeGuidance(BaseGuidance):
6768

6869
_input_predictions = ["pred_cond", "pred_uncond"]
6970

71+
@register_to_config
7072
def __init__(
7173
self,
7274
guidance_scale: float = 7.5,

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from .guider_utils import BaseGuidance, rescale_noise_cfg
2122

2223

@@ -58,6 +59,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
5859

5960
_input_predictions = ["pred_cond", "pred_uncond"]
6061

62+
@register_to_config
6163
def __init__(
6264
self,
6365
guidance_scale: float = 7.5,

src/diffusers/guiders/entropy_rectifying_guidance.py

Whitespace-only changes.

src/diffusers/guiders/guider_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,24 @@
1616

1717
import torch
1818

19+
from ..configuration_utils import ConfigMixin
1920
from ..utils import get_logger
2021

2122

2223
if TYPE_CHECKING:
2324
from ..modular_pipelines.modular_pipeline import BlockState
2425

2526

27+
GUIDER_CONFIG_NAME = "guider_config.json"
28+
29+
2630
logger = get_logger(__name__) # pylint: disable=invalid-name
2731

2832

29-
class BaseGuidance:
33+
class BaseGuidance(ConfigMixin):
3034
r"""Base class providing the skeleton for implementing guidance techniques."""
3135

36+
config_name = GUIDER_CONFIG_NAME
3237
_input_predictions = None
3338
_identifier_key = "__guidance_identifier__"
3439

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import List, Optional, Union
1616

17+
from ..configuration_utils import register_to_config
1718
from ..hooks import LayerSkipConfig
1819
from .skip_layer_guidance import SkipLayerGuidance
1920

@@ -70,6 +71,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
7071
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
7172
# for each model architecture.
7273

74+
@register_to_config
7375
def __init__(
7476
self,
7577
guidance_scale: float = 7.5,

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from ..hooks import HookRegistry, LayerSkipConfig
2122
from ..hooks.layer_skip import _apply_layer_skip_hook
2223
from .guider_utils import BaseGuidance, rescale_noise_cfg
@@ -86,6 +87,7 @@ class SkipLayerGuidance(BaseGuidance):
8687

8788
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
8889

90+
@register_to_config
8991
def __init__(
9092
self,
9193
guidance_scale: float = 7.5,

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from ..hooks import HookRegistry
2122
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
2223
from .guider_utils import BaseGuidance, rescale_noise_cfg
@@ -76,6 +77,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
7677

7778
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
7879

80+
@register_to_config
7981
def __init__(
8082
self,
8183
guidance_scale: float = 7.5,

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from .guider_utils import BaseGuidance, rescale_noise_cfg
2122

2223

@@ -49,6 +50,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
4950

5051
_input_predictions = ["pred_cond", "pred_uncond"]
5152

53+
@register_to_config
5254
def __init__(
5355
self,
5456
guidance_scale: float = 7.5,

0 commit comments

Comments
 (0)