Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/guiders/adaptive_projected_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


Expand Down Expand Up @@ -53,6 +54,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/auto_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
Expand Down Expand Up @@ -60,6 +61,7 @@ class AutoGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


Expand Down Expand Up @@ -67,6 +68,7 @@ class ClassifierFreeGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/classifier_free_zero_star_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


Expand Down Expand Up @@ -58,6 +59,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
Empty file.
7 changes: 6 additions & 1 deletion src/diffusers/guiders/guider_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,24 @@

import torch

from ..configuration_utils import ConfigMixin
from ..utils import get_logger


if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState


GUIDER_CONFIG_NAME = "guider_config.json"


logger = get_logger(__name__) # pylint: disable=invalid-name


class BaseGuidance:
class BaseGuidance(ConfigMixin):
r"""Base class providing the skeleton for implementing guidance techniques."""

config_name = GUIDER_CONFIG_NAME
_input_predictions = None
_identifier_key = "__guidance_identifier__"

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/perturbed_attention_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import List, Optional, Union

from ..configuration_utils import register_to_config
from ..hooks import LayerSkipConfig
from .skip_layer_guidance import SkipLayerGuidance

Expand Down Expand Up @@ -70,6 +71,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
# for each model architecture.

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/skip_layer_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
Expand Down Expand Up @@ -86,6 +87,7 @@ class SkipLayerGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/smoothed_energy_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
Expand Down Expand Up @@ -76,6 +77,7 @@ class SmoothedEnergyGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/guiders/tangential_classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


Expand Down Expand Up @@ -49,6 +50,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):

_input_predictions = ["pred_cond", "pred_uncond"]

@register_to_config
def __init__(
self,
guidance_scale: float = 7.5,
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,7 +1911,7 @@ def update(self, **kwargs):
loader.update(unet=new_unet_model, text_encoder=new_text_encoder)

# Update configuration values
loader.update(requires_safety_checker=False, guidance_rescale=0.7)
loader.update(requires_safety_checker=False)

# Update both components and configs together
loader.update(unet=new_unet_model, requires_safety_checker=False)
Expand Down