Skip to content

Commit bf7590b

Browse files
committed
update
1 parent bb1d9a8 commit bf7590b

10 files changed

+55
-27
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -92,7 +92,10 @@ def prepare_inputs(
9292
data_batches.append(data_batch)
9393
return data_batches
9494

95-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
95+
def forward(
96+
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
97+
) -> Tuple[torch.Tensor, GuiderInput]:
98+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
9699
pred = None
97100

98101
if not self._is_apg_enabled():
@@ -111,7 +114,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
111114
if self.guidance_rescale > 0.0:
112115
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
113116

114-
return pred, {}
117+
return pred, guider_inputs
115118

116119
@property
117120
def is_conditional(self) -> bool:

src/diffusers/guiders/auto_guidance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
23-
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -145,7 +145,10 @@ def prepare_inputs(
145145
data_batches.append(data_batch)
146146
return data_batches
147147

148-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
148+
def forward(
149+
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
150+
) -> Tuple[torch.Tensor, GuiderInput]:
151+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
149152
pred = None
150153

151154
if not self._is_ag_enabled():
@@ -158,7 +161,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
158161
if self.guidance_rescale > 0.0:
159162
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
160163

161-
return pred, {}
164+
return pred, guider_inputs
162165

163166
@property
164167
def is_conditional(self) -> bool:

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -96,7 +96,10 @@ def prepare_inputs(
9696
data_batches.append(data_batch)
9797
return data_batches
9898

99-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
99+
def forward(
100+
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
101+
) -> Tuple[torch.Tensor, GuiderInput]:
102+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
100103
pred = None
101104

102105
if not self._is_cfg_enabled():
@@ -109,7 +112,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
109112
if self.guidance_rescale > 0.0:
110113
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
111114

112-
return pred, {}
115+
return pred, guider_inputs
113116

114117
@property
115118
def is_conditional(self) -> bool:

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -89,7 +89,10 @@ def prepare_inputs(
8989
data_batches.append(data_batch)
9090
return data_batches
9191

92-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
92+
def forward(
93+
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
94+
) -> Tuple[torch.Tensor, GuiderInput]:
95+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
9396
pred = None
9497

9598
if self._step < self.zero_init_steps:
@@ -109,7 +112,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
109112
if self.guidance_rescale > 0.0:
110113
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
111114

112-
return pred, {}
115+
return pred, guider_inputs
113116

114117
@property
115118
def is_conditional(self) -> bool:

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ..configuration_utils import register_to_config
2121
from ..utils import is_kornia_available
22-
from .guider_utils import BaseGuidance, rescale_noise_cfg
22+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2323

2424

2525
if TYPE_CHECKING:
@@ -230,7 +230,10 @@ def prepare_inputs(
230230
data_batches.append(data_batch)
231231
return data_batches
232232

233-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
233+
def forward(
234+
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
235+
) -> Tuple[torch.Tensor, GuiderInput]:
236+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
234237
pred = None
235238

236239
if not self._is_fdg_enabled():
@@ -277,7 +280,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
277280
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
278281
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
279282

280-
return pred, {}
283+
return pred, guider_inputs
281284

282285
@property
283286
def is_conditional(self) -> bool:

src/diffusers/guiders/guider_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
from dataclasses import dataclass
1617
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1718

1819
import torch
@@ -284,6 +285,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
284285
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
285286

286287

288+
@dataclass
289+
class GuiderInput:
290+
pred_cond: Optional[torch.Tensor]
291+
pred_uncond: Optional[torch.Tensor]
292+
293+
287294
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
288295
r"""
289296
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
2323
from ..utils import get_logger
24-
from .guider_utils import BaseGuidance, rescale_noise_cfg
24+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2525

2626

2727
if TYPE_CHECKING:
@@ -197,7 +197,8 @@ def forward(
197197
pred_cond: torch.Tensor,
198198
pred_uncond: Optional[torch.Tensor] = None,
199199
pred_cond_skip: Optional[torch.Tensor] = None,
200-
) -> torch.Tensor:
200+
) -> Tuple[torch.Tensor, GuiderInput]:
201+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
201202
pred = None
202203

203204
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -219,7 +220,7 @@ def forward(
219220
if self.guidance_rescale > 0.0:
220221
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
221222

222-
return pred, {}
223+
return pred, guider_inputs
223224

224225
@property
225226
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
23-
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -192,7 +192,8 @@ def forward(
192192
pred_cond: torch.Tensor,
193193
pred_uncond: Optional[torch.Tensor] = None,
194194
pred_cond_skip: Optional[torch.Tensor] = None,
195-
) -> torch.Tensor:
195+
) -> Tuple[torch.Tensor, GuiderInput]:
196+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
196197
pred = None
197198

198199
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -214,7 +215,7 @@ def forward(
214215
if self.guidance_rescale > 0.0:
215216
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
216217

217-
return pred, {}
218+
return pred, guider_inputs
218219

219220
@property
220221
def is_conditional(self) -> bool:

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry
2222
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
23-
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -181,7 +181,8 @@ def forward(
181181
pred_cond: torch.Tensor,
182182
pred_uncond: Optional[torch.Tensor] = None,
183183
pred_cond_seg: Optional[torch.Tensor] = None,
184-
) -> torch.Tensor:
184+
) -> Tuple[torch.Tensor, GuiderInput]:
185+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
185186
pred = None
186187

187188
if not self._is_cfg_enabled() and not self._is_seg_enabled():
@@ -203,7 +204,7 @@ def forward(
203204
if self.guidance_rescale > 0.0:
204205
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
205206

206-
return pred, {}
207+
return pred, guider_inputs
207208

208209
@property
209210
def is_conditional(self) -> bool:

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -78,7 +78,10 @@ def prepare_inputs(
7878
data_batches.append(data_batch)
7979
return data_batches
8080

81-
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
81+
def forward(
82+
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
83+
) -> Tuple[torch.Tensor, GuiderInput]:
84+
guider_inputs = GuiderInput(pred_cond, pred_uncond)
8285
pred = None
8386

8487
if not self._is_tcfg_enabled():
@@ -89,7 +92,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
8992
if self.guidance_rescale > 0.0:
9093
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
9194

92-
return pred, {}
95+
return pred, guider_inputs
9396

9497
@property
9598
def is_conditional(self) -> bool:

0 commit comments

Comments
 (0)