diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 81137db106a0..92b1fd5a1c2c 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -18,7 +18,7 @@ import torch from ..configuration_utils import register_to_config -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -92,7 +92,7 @@ def prepare_inputs( data_batches.append(data_batch) return data_batches - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None if not self._is_apg_enabled(): @@ -111,7 +111,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index e1642211d393..8f4d7b11c942 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -20,7 +20,7 @@ from ..configuration_utils import register_to_config from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -145,7 +145,7 @@ def prepare_inputs( data_batches.append(data_batch) return data_batches - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None if not self._is_ag_enabled(): @@ -158,7 +158,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 7e72b92fcee2..050590336ffb 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -18,7 +18,7 @@ import torch from ..configuration_utils import register_to_config -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -96,7 +96,7 @@ def prepare_inputs( data_batches.append(data_batch) return data_batches - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None if not self._is_cfg_enabled(): @@ -109,7 +109,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 85d5cc62d4e7..b64e35633114 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -18,7 +18,7 @@ import torch from ..configuration_utils import register_to_config -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -89,7 +89,7 @@ def prepare_inputs( data_batches.append(data_batch) return data_batches - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None if self._step < self.zero_init_steps: @@ -109,7 +109,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 35bc99ac4dde..2bf2f430b1b3 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -19,7 +19,7 @@ from ..configuration_utils import register_to_config from ..utils import is_kornia_available -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -230,7 +230,7 @@ def prepare_inputs( data_batches.append(data_batch) return data_batches - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None if not self._is_fdg_enabled(): @@ -277,7 +277,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0]) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 9dc83a7f1dcc..a6f2e76dc337 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -20,7 +20,7 @@ from typing_extensions import Self from ..configuration_utils import ConfigMixin -from ..utils import PushToHubMixin, get_logger +from ..utils import BaseOutput, PushToHubMixin, get_logger if TYPE_CHECKING: @@ -284,6 +284,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) +class GuiderOutput(BaseOutput): + pred: torch.Tensor + pred_cond: Optional[torch.Tensor] + pred_uncond: Optional[torch.Tensor] + + def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index 1b2256732ffc..e294e8d0db59 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -21,7 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook from ..utils import get_logger -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -197,7 +197,7 @@ def forward( pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, pred_cond_skip: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> GuiderOutput: pred = None if not self._is_cfg_enabled() and not self._is_slg_enabled(): @@ -219,7 +219,7 @@ def forward( if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 68a657960a45..3530df8b0a18 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -20,7 +20,7 @@ from ..configuration_utils import register_to_config from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -192,7 +192,7 @@ def forward( pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, pred_cond_skip: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> GuiderOutput: pred = None if not self._is_cfg_enabled() and not self._is_slg_enabled(): @@ -214,7 +214,7 @@ def forward( if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index d8e8a3cf2fa8..767d20b62f85 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -20,7 +20,7 @@ from ..configuration_utils import register_to_config from ..hooks import HookRegistry from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -181,7 +181,7 @@ def forward( pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, pred_cond_seg: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> GuiderOutput: pred = None if not self._is_cfg_enabled() and not self._is_seg_enabled(): @@ -203,7 +203,7 @@ def forward( if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index b3187e526316..df1e69fe71f5 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -18,7 +18,7 @@ import torch from ..configuration_utils import register_to_config -from .guider_utils import BaseGuidance, rescale_noise_cfg +from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg if TYPE_CHECKING: @@ -78,7 +78,7 @@ def prepare_inputs( data_batches.append(data_batch) return data_batches - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None if not self._is_tcfg_enabled(): @@ -89,7 +89,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred, {} + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) @property def is_conditional(self) -> bool: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 96df9711cc62..34e07dff8ab8 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -238,7 +238,7 @@ def __call__( components.guider.cleanup_models(components.unet) # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + block_state.noise_pred = components.guider(guider_state)[0] return components, block_state @@ -433,7 +433,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl components.guider.cleanup_models(components.unet) # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + block_state.noise_pred = components.guider(guider_state)[0] return components, block_state @@ -492,7 +492,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl t, block_state.latents, **block_state.extra_step_kwargs, - **block_state.scheduler_step_kwargs, return_dict=False, )[0] @@ -590,7 +589,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl t, block_state.latents, **block_state.extra_step_kwargs, - **block_state.scheduler_step_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 9871d4ad618c..34297bcfb589 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -127,7 +127,7 @@ def __call__( components.guider.cleanup_models(components.transformer) # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + block_state.noise_pred = components.guider(guider_state)[0] return components, block_state @@ -171,7 +171,6 @@ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: i block_state.noise_pred.float(), t, block_state.latents.float(), - **block_state.scheduler_step_kwargs, return_dict=False, )[0]