Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/diffusers/guiders/adaptive_projected_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -92,7 +92,10 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_apg_enabled():
Expand All @@ -111,7 +114,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/guiders/auto_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -145,7 +145,10 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_ag_enabled():
Expand All @@ -158,7 +161,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/guiders/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -96,7 +96,10 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_cfg_enabled():
Expand All @@ -109,7 +112,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/guiders/classifier_free_zero_star_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -89,7 +89,10 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if self._step < self.zero_init_steps:
Expand All @@ -109,7 +112,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/guiders/frequency_decoupled_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..configuration_utils import register_to_config
from ..utils import is_kornia_available
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -230,7 +230,10 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_fdg_enabled():
Expand Down Expand Up @@ -277,7 +280,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions src/diffusers/guiders/guider_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -284,6 +285,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)


@dataclass
class GuiderInput:
    """Bundle of the raw predictions a guider's ``forward`` was called with.

    The guiders construct this as ``GuiderInput(pred_cond, pred_uncond)`` at the
    top of ``forward`` and return it next to the guided prediction, so the
    caller can still reach the un-guided tensors.

    NOTE(review): this is a plain dataclass, not a mapping, so it cannot be
    expanded with ``**`` into keyword arguments -- confirm call sites access
    the fields by attribute.
    """

    # Prediction from the conditional branch. Typed Optional here, although the
    # ``forward`` signatures that build this object require a tensor.
    pred_cond: Optional[torch.Tensor]
    # Prediction from the unconditional branch; ``forward`` defaults it to None.
    pred_uncond: Optional[torch.Tensor]


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Expand Down
7 changes: 4 additions & 3 deletions src/diffusers/guiders/perturbed_attention_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -197,7 +197,8 @@ def forward(
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_cfg_enabled() and not self._is_slg_enabled():
Expand All @@ -219,7 +220,7 @@ def forward(
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
Expand Down
7 changes: 4 additions & 3 deletions src/diffusers/guiders/skip_layer_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -192,7 +192,8 @@ def forward(
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_cfg_enabled() and not self._is_slg_enabled():
Expand All @@ -214,7 +215,7 @@ def forward(
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
7 changes: 4 additions & 3 deletions src/diffusers/guiders/smoothed_energy_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -181,7 +181,8 @@ def forward(
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_cfg_enabled() and not self._is_seg_enabled():
Expand All @@ -203,7 +204,7 @@ def forward(
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs

@property
def is_conditional(self) -> bool:
Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/guiders/tangential_classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg


if TYPE_CHECKING:
Expand Down Expand Up @@ -78,7 +78,10 @@ def prepare_inputs(
data_batches.append(data_batch)
return data_batches

def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, GuiderInput]:
guider_inputs = GuiderInput(pred_cond, pred_uncond)
pred = None

if not self._is_tcfg_enabled():
Expand All @@ -89,7 +92,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

return pred, {}
return pred, guider_inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just define a `GuiderOutput`

and do:

return GuiderOutput(pred=pred, pred_cond=pred_cond)

And in the denoiser step, we can do the following for now, since we don't need the other variables for SDXL/Wan/Flux:

block_state.noise_pred = components.guider(guider_state)[0]


@property
def is_conditional(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __call__(
components.guider.cleanup_models(components.unet)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
block_state.noise_pred, block_state.guider_inputs = components.guider(guider_state)

return components, block_state

Expand Down Expand Up @@ -433,7 +433,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
components.guider.cleanup_models(components.unet)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
block_state.noise_pred, block_state.guider_inputs = components.guider(guider_state)

return components, block_state

Expand Down Expand Up @@ -492,7 +492,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
**block_state.guider_inputs,
return_dict=False,
)[0]

Expand Down Expand Up @@ -590,7 +590,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
t,
block_state.latents,
**block_state.extra_step_kwargs,
**block_state.scheduler_step_kwargs,
**block_state.guider_inputs,
return_dict=False,
)[0]

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/modular_pipelines/wan/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __call__(
components.guider.cleanup_models(components.transformer)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
block_state.noise_pred, block_state.guider_inputs = components.guider(guider_state)

return components, block_state

Expand Down Expand Up @@ -171,7 +171,7 @@ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: i
block_state.noise_pred.float(),
t,
block_state.latents.float(),
**block_state.scheduler_step_kwargs,
**block_state.guider_inputs,
return_dict=False,
)[0]

Expand Down
Loading