Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/diffusers/guiders/adaptive_projected_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -92,7 +92,7 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if not self._is_apg_enabled():
Expand All @@ -111,7 +111,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/auto_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -145,7 +145,7 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if not self._is_ag_enabled():
Expand All @@ -158,7 +158,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -96,7 +96,7 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if not self._is_cfg_enabled():
Expand All @@ -109,7 +109,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/classifier_free_zero_star_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -89,7 +89,7 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if self._step < self.zero_init_steps:
Expand All @@ -109,7 +109,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/frequency_decoupled_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..configuration_utils import register_to_config
from ..utils import is_kornia_available
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -230,7 +230,7 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if not self._is_fdg_enabled():
Expand Down Expand Up @@ -277,7 +277,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
8 changes: 7 additions & 1 deletion src/diffusers/guiders/guider_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing_extensions import Self

from ..configuration_utils import ConfigMixin
from ..utils import PushToHubMixin, get_logger
from ..utils import BaseOutput, PushToHubMixin, get_logger


if TYPE_CHECKING:
Expand Down Expand Up @@ -284,6 +284,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)


class GuiderOutput(BaseOutput):
pred: torch.Tensor
pred_cond: Optional[torch.Tensor]
pred_uncond: Optional[torch.Tensor]


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/perturbed_attention_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -197,7 +197,7 @@ def forward(
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> GuiderOutput:
pred = None

if not self._is_cfg_enabled() and not self._is_slg_enabled():
Expand All @@ -219,7 +219,7 @@ def forward(
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/skip_layer_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -192,7 +192,7 @@ def forward(
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> GuiderOutput:
pred = None

if not self._is_cfg_enabled() and not self._is_slg_enabled():
Expand All @@ -214,7 +214,7 @@ def forward(
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/smoothed_energy_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -181,7 +181,7 @@ def forward(
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> GuiderOutput:
pred = None

if not self._is_cfg_enabled() and not self._is_seg_enabled():
Expand All @@ -203,7 +203,7 @@ def forward(
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/guiders/tangential_classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -78,7 +78,7 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

if not self._is_tcfg_enabled():
Expand All @@ -89,7 +89,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

@property
def is_conditional(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __call__(
components.guider.cleanup_models(components.unet)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
block_state.noise_pred = components.guider(guider_state)[0]

return components, block_state

Expand Down Expand Up @@ -433,7 +433,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
components.guider.cleanup_models(components.unet)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
block_state.noise_pred = components.guider(guider_state)[0]

return components, block_state

Expand Down Expand Up @@ -492,7 +492,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]

Expand Down Expand Up @@ -590,7 +589,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]

Expand Down
3 changes: 1 addition & 2 deletions src/diffusers/modular_pipelines/wan/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __call__(
components.guider.cleanup_models(components.transformer)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
block_state.noise_pred = components.guider(guider_state)[0]

return components, block_state

Expand Down Expand Up @@ -171,7 +171,6 @@ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: i
block_state.noise_pred.float(),
t,
block_state.latents.float(),
**block_state.scheduler_step_kwargs,
return_dict=False,
)[0]

Expand Down
Loading