Skip to content

Commit ac15978

Browse files
committed
apply review suggestions
1 parent 86e96c9 commit ac15978

12 files changed

+32
-52
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -92,10 +92,7 @@ def prepare_inputs(
9292
data_batches.append(data_batch)
9393
return data_batches
9494

95-
def forward(
96-
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
97-
) -> Tuple[torch.Tensor, GuiderInput]:
98-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
95+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
9996
pred = None
10097

10198
if not self._is_apg_enabled():
@@ -114,7 +111,7 @@ def forward(
114111
if self.guidance_rescale > 0.0:
115112
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
116113

117-
return pred, guider_inputs
114+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
118115

119116
@property
120117
def is_conditional(self) -> bool:

src/diffusers/guiders/auto_guidance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
23-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -145,10 +145,7 @@ def prepare_inputs(
145145
data_batches.append(data_batch)
146146
return data_batches
147147

148-
def forward(
149-
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
150-
) -> Tuple[torch.Tensor, GuiderInput]:
151-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
148+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
152149
pred = None
153150

154151
if not self._is_ag_enabled():
@@ -161,7 +158,7 @@ def forward(
161158
if self.guidance_rescale > 0.0:
162159
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
163160

164-
return pred, guider_inputs
161+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
165162

166163
@property
167164
def is_conditional(self) -> bool:

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -96,10 +96,7 @@ def prepare_inputs(
9696
data_batches.append(data_batch)
9797
return data_batches
9898

99-
def forward(
100-
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
101-
) -> Tuple[torch.Tensor, GuiderInput]:
102-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
99+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
103100
pred = None
104101

105102
if not self._is_cfg_enabled():
@@ -112,7 +109,7 @@ def forward(
112109
if self.guidance_rescale > 0.0:
113110
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
114111

115-
return pred, guider_inputs
112+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
116113

117114
@property
118115
def is_conditional(self) -> bool:

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -89,10 +89,7 @@ def prepare_inputs(
8989
data_batches.append(data_batch)
9090
return data_batches
9191

92-
def forward(
93-
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
94-
) -> Tuple[torch.Tensor, GuiderInput]:
95-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
92+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
9693
pred = None
9794

9895
if self._step < self.zero_init_steps:
@@ -112,7 +109,7 @@ def forward(
112109
if self.guidance_rescale > 0.0:
113110
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
114111

115-
return pred, guider_inputs
112+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
116113

117114
@property
118115
def is_conditional(self) -> bool:

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ..configuration_utils import register_to_config
2121
from ..utils import is_kornia_available
22-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
22+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2323

2424

2525
if TYPE_CHECKING:
@@ -230,10 +230,7 @@ def prepare_inputs(
230230
data_batches.append(data_batch)
231231
return data_batches
232232

233-
def forward(
234-
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
235-
) -> Tuple[torch.Tensor, GuiderInput]:
236-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
233+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
237234
pred = None
238235

239236
if not self._is_fdg_enabled():
@@ -280,7 +277,7 @@ def forward(
280277
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
281278
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
282279

283-
return pred, guider_inputs
280+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
284281

285282
@property
286283
def is_conditional(self) -> bool:

src/diffusers/guiders/guider_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
286286

287287

288288
@dataclass
289-
class GuiderInput:
289+
class GuiderOutput:
290+
pred: torch.Tensor
290291
pred_cond: Optional[torch.Tensor]
291292
pred_uncond: Optional[torch.Tensor]
292293

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
2323
from ..utils import get_logger
24-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
24+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2525

2626

2727
if TYPE_CHECKING:
@@ -197,8 +197,7 @@ def forward(
197197
pred_cond: torch.Tensor,
198198
pred_uncond: Optional[torch.Tensor] = None,
199199
pred_cond_skip: Optional[torch.Tensor] = None,
200-
) -> Tuple[torch.Tensor, GuiderInput]:
201-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
200+
) -> GuiderOutput:
202201
pred = None
203202

204203
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -220,7 +219,7 @@ def forward(
220219
if self.guidance_rescale > 0.0:
221220
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
222221

223-
return pred, guider_inputs
222+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
224223

225224
@property
226225
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry, LayerSkipConfig
2222
from ..hooks.layer_skip import _apply_layer_skip_hook
23-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -192,8 +192,7 @@ def forward(
192192
pred_cond: torch.Tensor,
193193
pred_uncond: Optional[torch.Tensor] = None,
194194
pred_cond_skip: Optional[torch.Tensor] = None,
195-
) -> Tuple[torch.Tensor, GuiderInput]:
196-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
195+
) -> GuiderOutput:
197196
pred = None
198197

199198
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -215,7 +214,7 @@ def forward(
215214
if self.guidance_rescale > 0.0:
216215
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
217216

218-
return pred, guider_inputs
217+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
219218

220219
@property
221220
def is_conditional(self) -> bool:

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import register_to_config
2121
from ..hooks import HookRegistry
2222
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
23-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
23+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2424

2525

2626
if TYPE_CHECKING:
@@ -181,8 +181,7 @@ def forward(
181181
pred_cond: torch.Tensor,
182182
pred_uncond: Optional[torch.Tensor] = None,
183183
pred_cond_seg: Optional[torch.Tensor] = None,
184-
) -> Tuple[torch.Tensor, GuiderInput]:
185-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
184+
) -> GuiderOutput:
186185
pred = None
187186

188187
if not self._is_cfg_enabled() and not self._is_seg_enabled():
@@ -204,7 +203,7 @@ def forward(
204203
if self.guidance_rescale > 0.0:
205204
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
206205

207-
return pred, guider_inputs
206+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
208207

209208
@property
210209
def is_conditional(self) -> bool:

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from ..configuration_utils import register_to_config
21-
from .guider_utils import BaseGuidance, GuiderInput, rescale_noise_cfg
21+
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
2222

2323

2424
if TYPE_CHECKING:
@@ -78,10 +78,7 @@ def prepare_inputs(
7878
data_batches.append(data_batch)
7979
return data_batches
8080

81-
def forward(
82-
self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
83-
) -> Tuple[torch.Tensor, GuiderInput]:
84-
guider_inputs = GuiderInput(pred_cond, pred_uncond)
81+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
8582
pred = None
8683

8784
if not self._is_tcfg_enabled():
@@ -92,7 +89,7 @@ def forward(
9289
if self.guidance_rescale > 0.0:
9390
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
9491

95-
return pred, guider_inputs
92+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
9693

9794
@property
9895
def is_conditional(self) -> bool:

0 commit comments

Comments
 (0)