Skip to content

Commit e889530

Browse files
authored
Merge branch 'main' into deprecate-jax
2 parents f4169eb + 22b229b commit e889530

14 files changed

+401
-39
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -92,7 +92,7 @@ def prepare_inputs(
9292
data_batches.append(data_batch)
9393
return data_batches
9494

95-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
95+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
9696
pred = None
9797

9898
if not self._is_apg_enabled():
@@ -111,7 +111,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
111111
if self.guidance_rescale > 0.0:
112112
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
113113

114-
return pred, {}
114+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
115115

116116
@property
117117
def is_conditional(self) -> bool:

src/diffusers/guiders/auto_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
23-
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -145,7 +145,7 @@ def prepare_inputs(
145145
data_batches.append(data_batch)
146146
return data_batches
147147

148-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
148+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
149149
pred = None
150150

151151
if not self._is_ag_enabled():
@@ -158,7 +158,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
158158
if self.guidance_rescale > 0.0:
159159
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
160160

161-
return pred, {}
161+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
162162

163163
@property
164164
def is_conditional(self) -> bool:

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -96,7 +96,7 @@ def prepare_inputs(
9696
data_batches.append(data_batch)
9797
return data_batches
9898

99-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
99+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
100100
pred = None
101101

102102
if not self._is_cfg_enabled():
@@ -109,7 +109,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
109109
if self.guidance_rescale > 0.0:
110110
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
111111

112-
return pred, {}
112+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
113113

114114
@property
115115
def is_conditional(self) -> bool:

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -89,7 +89,7 @@ def prepare_inputs(
8989
data_batches.append(data_batch)
9090
return data_batches
9191

92-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
92+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
9393
pred = None
9494

9595
if self._step < self.zero_init_steps:
@@ -109,7 +109,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
109109
if self.guidance_rescale > 0.0:
110110
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
111111

112-
return pred, {}
112+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
113113

114114
@property
115115
def is_conditional(self) -> bool:

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ..configuration_utils import register_to_config
2121
from ..utils import is_kornia_available
22-
from .guider_utils import BaseGuidance, rescale_noise_cfg
22+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2323

2424

2525
if TYPE_CHECKING:
@@ -230,7 +230,7 @@ def prepare_inputs(
230230
data_batches.append(data_batch)
231231
return data_batches
232232

233-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
233+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
234234
pred = None
235235

236236
if not self._is_fdg_enabled():
@@ -277,7 +277,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
277277
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
278278
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
279279

280-
return pred, {}
280+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
281281

282282
@property
283283
def is_conditional(self) -> bool:

src/diffusers/guiders/guider_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing_extensions import Self
2121

2222
from ..configuration_utils import ConfigMixin
23-
from ..utils import PushToHubMixin, get_logger
23+
from ..utils import BaseOutput, PushToHubMixin, get_logger
2424

2525

2626
if TYPE_CHECKING:
@@ -284,6 +284,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
284284
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
285285

286286

287+
class GuiderOutput(BaseOutput):
288+
pred: torch.Tensor
289+
pred_cond: Optional[torch.Tensor]
290+
pred_uncond: Optional[torch.Tensor]
291+
292+
287293
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
288294
r"""
289295
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
2323
from ..utils import get_logger
24-
from .guider_utils import BaseGuidance, rescale_noise_cfg
24+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2525

2626

2727
if TYPE_CHECKING:
@@ -197,7 +197,7 @@ def forward(
197197
pred_cond: torch.Tensor,
198198
pred_uncond: Optional[torch.Tensor] = None,
199199
pred_cond_skip: Optional[torch.Tensor] = None,
200-
) -> torch.Tensor:
200+
) -> GuiderOutput:
201201
pred = None
202202

203203
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -219,7 +219,7 @@ def forward(
219219
if self.guidance_rescale > 0.0:
220220
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
221221

222-
return pred, {}
222+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
223223

224224
@property
225225
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
23-
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -192,7 +192,7 @@ def forward(
192192
pred_cond: torch.Tensor,
193193
pred_uncond: Optional[torch.Tensor] = None,
194194
pred_cond_skip: Optional[torch.Tensor] = None,
195-
) -> torch.Tensor:
195+
) -> GuiderOutput:
196196
pred = None
197197

198198
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -214,7 +214,7 @@ def forward(
214214
if self.guidance_rescale > 0.0:
215215
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
216216

217-
return pred, {}
217+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
218218

219219
@property
220220
def is_conditional(self) -> bool:

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry
2222
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
23-
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -181,7 +181,7 @@ def forward(
181181
pred_cond: torch.Tensor,
182182
pred_uncond: Optional[torch.Tensor] = None,
183183
pred_cond_seg: Optional[torch.Tensor] = None,
184-
) -> torch.Tensor:
184+
) -> GuiderOutput:
185185
pred = None
186186

187187
if not self._is_cfg_enabled() and not self._is_seg_enabled():
@@ -203,7 +203,7 @@ def forward(
203203
if self.guidance_rescale > 0.0:
204204
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
205205

206-
return pred, {}
206+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
207207

208208
@property
209209
def is_conditional(self) -> bool:

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -78,7 +78,7 @@ def prepare_inputs(
7878
data_batches.append(data_batch)
7979
return data_batches
8080

81-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
81+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
8282
pred = None
8383

8484
if not self._is_tcfg_enabled():
@@ -89,7 +89,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
8989
if self.guidance_rescale > 0.0:
9090
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
9191

92-
return pred, {}
92+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
9393

9494
@property
9595
def is_conditional(self) -> bool:

0 commit comments

Comments
 (0)