Skip to content

Commit 676e672

Browse files
committed
update
1 parent 3e46c86 commit 676e672

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919

20+
from ..configuration_utils import register_to_config
2021
from .guider_utils import BaseGuidance, rescale_noise_cfg
2122

2223

@@ -67,6 +68,7 @@ class ClassifierFreeGuidance(BaseGuidance):
6768

6869
_input_predictions = ["pred_cond", "pred_uncond"]
6970

71+
@register_to_config
7072
def __init__(
7173
self,
7274
guidance_scale: float = 7.5,

src/diffusers/guiders/guider_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,24 @@
1616

1717
import torch
1818

19+
from ..configuration_utils import ConfigMixin
1920
from ..utils import get_logger
2021

2122

2223
if TYPE_CHECKING:
2324
from ..modular_pipelines.modular_pipeline import BlockState
2425

2526

27+
GUIDER_CONFIG_NAME = "guider_config.json"
28+
29+
2630
logger = get_logger(__name__) # pylint: disable=invalid-name
2731

2832

29-
class BaseGuidance:
33+
class BaseGuidance(ConfigMixin):
3034
r"""Base class providing the skeleton for implementing guidance techniques."""
3135

36+
config_name = GUIDER_CONFIG_NAME
3237
_input_predictions = None
3338
_identifier_key = "__guidance_identifier__"
3439

0 commit comments

Comments
 (0)