Skip to content

Commit 4664356

Browse files
committed
refactor
1 parent 77324c4 commit 4664356

File tree

6 files changed

+98
-23
lines changed

6 files changed

+98
-23
lines changed

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin):
4242
we use the diffusers-native implementation that has been in the codebase for a long time.
4343
"""
4444

45+
_input_predictions = ["pred_cond", "pred_uncond"]
46+
4547
def __init__(
4648
self,
4749
guidance_scale: float = 7.5,
@@ -51,6 +53,8 @@ def __init__(
5153
guidance_rescale: float = 0.0,
5254
use_original_formulation: bool = False,
5355
):
56+
super().__init__()
57+
5458
self.guidance_scale = guidance_scale
5559
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
5660
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
@@ -68,7 +72,7 @@ def prepare_inputs(self, *args):
6872
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
6973
pred = None
7074

71-
if math.isclose(self.guidance_scale, 1.0):
75+
if self._is_cfg_enabled():
7276
pred = pred_cond
7377
else:
7478
pred = normalized_guidance(
@@ -89,10 +93,16 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
8993
@property
9094
def num_conditions(self) -> int:
9195
num_conditions = 1
92-
if not math.isclose(self.guidance_scale, 1.0):
96+
if self._is_cfg_enabled():
9397
num_conditions += 1
9498
return num_conditions
9599

100+
def _is_cfg_enabled(self) -> bool:
101+
if self.use_original_formulation:
102+
return not math.isclose(self.guidance_scale, 0.0)
103+
else:
104+
return not math.isclose(self.guidance_scale, 1.0)
105+
96106

97107
class MomentumBuffer:
98108
def __init__(self, momentum: float):

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,21 @@ class ClassifierFreeGuidance(GuidanceMixin):
5656
we use the diffusers-native implementation that has been in the codebase for a long time.
5757
"""
5858

59+
_input_predictions = ["pred_cond", "pred_uncond"]
60+
5961
def __init__(
6062
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
6163
):
64+
super().__init__()
65+
6266
self.guidance_scale = guidance_scale
6367
self.guidance_rescale = guidance_rescale
6468
self.use_original_formulation = use_original_formulation
6569

6670
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
6771
pred = None
6872

69-
if math.isclose(self.guidance_scale, 1.0):
73+
if not self._is_cfg_enabled():
7074
pred = pred_cond
7175
else:
7276
shift = pred_cond - pred_uncond
@@ -81,6 +85,12 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
8185
@property
8286
def num_conditions(self) -> int:
8387
num_conditions = 1
84-
if not math.isclose(self.guidance_scale, 1.0):
88+
if self._is_cfg_enabled():
8589
num_conditions += 1
8690
return num_conditions
91+
92+
def _is_cfg_enabled(self) -> bool:
93+
if self.use_original_formulation:
94+
return not math.isclose(self.guidance_scale, 0.0)
95+
else:
96+
return not math.isclose(self.guidance_scale, 1.0)

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,17 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin):
4848
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
4949
"""
5050

51+
_input_predictions = ["pred_cond", "pred_uncond"]
52+
5153
def __init__(
5254
self,
5355
guidance_scale: float = 7.5,
5456
zero_init_steps: int = 1,
5557
guidance_rescale: float = 0.0,
5658
use_original_formulation: bool = False,
5759
):
60+
super().__init__()
61+
5862
self.guidance_scale = guidance_scale
5963
self.zero_init_steps = zero_init_steps
6064
self.guidance_rescale = guidance_rescale
@@ -65,7 +69,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
6569

6670
if self._step < self.zero_init_steps:
6771
pred = torch.zeros_like(pred_cond)
68-
elif math.isclose(self.guidance_scale, 1.0):
72+
elif self._is_cfg_enabled():
6973
pred = pred_cond
7074
else:
7175
shift = pred_cond - pred_uncond
@@ -85,10 +89,16 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
8589
@property
8690
def num_conditions(self) -> int:
8791
num_conditions = 1
88-
if not math.isclose(self.guidance_scale, 1.0):
92+
if self._is_cfg_enabled():
8993
num_conditions += 1
9094
return num_conditions
9195

96+
def _is_cfg_enabled(self) -> bool:
97+
if self.use_original_formulation:
98+
return not math.isclose(self.guidance_scale, 0.0)
99+
else:
100+
return not math.isclose(self.guidance_scale, 1.0)
101+
92102

93103
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
94104
cond = cond.float()

src/diffusers/guiders/guider_utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, List, Optional, Tuple, Union
15+
from typing import Any, Dict, List, Optional, Tuple, Union
1616

1717
import torch
1818

@@ -25,15 +25,26 @@
2525
class GuidanceMixin:
2626
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
2727

28+
_input_predictions = None
29+
2830
def __init__(self):
2931
self._step: int = None
3032
self._num_inference_steps: int = None
3133
self._timestep: torch.LongTensor = None
34+
self._preds: Dict[str, torch.Tensor] = {}
35+
self._num_outputs_prepared: int = 0
36+
37+
if self._input_predictions is None or not isinstance(self._input_predictions, list):
38+
raise ValueError(
39+
"`_input_predictions` must be a list of required prediction names for the guidance technique."
40+
)
3241

3342
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
3443
self._step = step
3544
self._num_inference_steps = num_inference_steps
3645
self._timestep = timestep
46+
self._preds = {}
47+
self._num_outputs_prepared = 0
3748

3849
def prepare_models(self, denoiser: torch.nn.Module) -> None:
3950
pass
@@ -63,15 +74,22 @@ def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]])
6374
)
6475
return tuple(list_of_inputs)
6576

77+
def prepare_outputs(self, pred: torch.Tensor) -> None:
78+
self._num_outputs_prepared += 1
79+
if self._num_outputs_prepared > self.num_conditions:
80+
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
81+
key = self._input_predictions[self._num_outputs_prepared - 1]
82+
self._preds[key] = pred
83+
6684
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
6785
pass
6886

69-
def __call__(self, *args) -> Any:
70-
if len(args) != self.num_conditions:
87+
def __call__(self, **kwargs) -> Any:
88+
if len(kwargs) != self.num_conditions:
7189
raise ValueError(
72-
f"Expected {self.num_conditions} arguments, but got {len(args)}. Please provide the correct number of arguments."
90+
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
7391
)
74-
return self.forward(*args)
92+
return self.forward(**kwargs)
7593

7694
def forward(self, *args, **kwargs) -> Any:
7795
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
@@ -80,6 +98,10 @@ def forward(self, *args, **kwargs) -> Any:
8098
def num_conditions(self) -> int:
8199
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
82100

101+
@property
102+
def outputs(self) -> Dict[str, torch.Tensor]:
103+
return self._preds
104+
83105

84106
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
85107
r"""

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class SkipLayerGuidance(GuidanceMixin):
7171
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
7272
"""
7373

74+
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
75+
7476
def __init__(
7577
self,
7678
guidance_scale: float = 7.5,
@@ -82,6 +84,8 @@ def __init__(
8284
guidance_rescale: float = 0.0,
8385
use_original_formulation: bool = False,
8486
):
87+
super().__init__()
88+
8589
self.guidance_scale = guidance_scale
8690
self.skip_layer_guidance_scale = skip_layer_guidance_scale
8791
self.skip_layer_guidance_start = skip_layer_guidance_start
@@ -157,6 +161,18 @@ def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]])
157161
)
158162
return tuple(list_of_inputs)
159163

164+
def prepare_outputs(self, pred: torch.Tensor) -> None:
165+
self._num_outputs_prepared += 1
166+
if self._num_outputs_prepared > self.num_conditions:
167+
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
168+
key = self._input_predictions[self._num_outputs_prepared - 1]
169+
if not self._is_cfg_enabled() and self._is_slg_enabled():
170+
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
171+
# to avoid writing into pred_uncond which is not used
172+
if self._num_outputs_prepared == 2:
173+
key = "pred_cond_skip"
174+
self._preds[key] = pred
175+
160176
def cleanup_models(self, denoiser: torch.nn.Module):
161177
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
162178
# Remove the hooks after inference
@@ -173,16 +189,16 @@ def forward(
173189
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
174190
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
175191

176-
if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0):
192+
if not self._is_cfg_enabled() and not self._is_slg_enabled():
177193
pred = pred_cond
178-
elif math.isclose(self.guidance_scale, 1.0):
194+
elif not self._is_cfg_enabled():
179195
if skip_start_step < self._step < skip_stop_step:
180196
shift = pred_cond - pred_cond_skip
181197
pred = pred_cond if self.use_original_formulation else pred_cond_skip
182198
pred = pred + self.skip_layer_guidance_scale * shift
183199
else:
184200
pred = pred_cond
185-
elif math.isclose(self.skip_layer_guidance_scale, 1.0):
201+
elif not self._is_slg_enabled():
186202
shift = pred_cond - pred_uncond
187203
pred = pred_cond if self.use_original_formulation else pred_uncond
188204
pred = pred + self.guidance_scale * shift
@@ -203,12 +219,19 @@ def forward(
203219
@property
204220
def num_conditions(self) -> int:
205221
num_conditions = 1
206-
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
207-
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
208-
209-
if not math.isclose(self.guidance_scale, 1.0):
222+
if self._is_cfg_enabled():
210223
num_conditions += 1
211-
if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step:
224+
if self._is_slg_enabled():
212225
num_conditions += 1
213-
214226
return num_conditions
227+
228+
def _is_cfg_enabled(self) -> bool:
229+
if self.use_original_formulation:
230+
return not math.isclose(self.guidance_scale, 0.0)
231+
else:
232+
return not math.isclose(self.guidance_scale, 1.0)
233+
234+
def _is_slg_enabled(self) -> bool:
    """Return whether skip-layer guidance (SLG) contributes at the current step.

    SLG is considered enabled only when both of the following hold:

    * ``skip_layer_guidance_scale`` is not the neutral value. ``forward``
      computes ``pred + skip_layer_guidance_scale * shift`` starting from
      ``pred_cond`` under ``use_original_formulation`` and from
      ``pred_cond_skip`` otherwise, so the neutral scale is ``0.0`` and
      ``1.0`` respectively (the same convention as ``_is_cfg_enabled``).
    * the current step lies strictly inside the
      ``(skip_layer_guidance_start, skip_layer_guidance_stop)`` window, with
      both fractions scaled by the number of inference steps.

    Returns:
        ``True`` if a skip-layer prediction is needed for this step.
    """
    neutral_scale = 0.0 if self.use_original_formulation else 1.0
    if math.isclose(self.skip_layer_guidance_scale, neutral_scale):
        return False

    # Before `set_state` has run, `_step` / `_num_inference_steps` are still
    # None and the window cannot be evaluated; decide on the scale alone
    # instead of failing with a TypeError on `float * None`.
    if self._num_inference_steps is None or self._step is None:
        return True

    skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
    skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
    return skip_start_step < self._step < skip_stop_step

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,6 @@ def __call__(
635635
crops_coords_top_left[0],
636636
)
637637

638-
noise_preds = []
639638
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
640639
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
641640
):
@@ -652,9 +651,10 @@ def __call__(
652651
attention_kwargs=attention_kwargs,
653652
return_dict=False,
654653
)[0]
655-
noise_preds.append(noise_pred)
654+
guidance.prepare_outputs(noise_pred)
656655

657-
noise_pred = guidance(*noise_preds)
656+
outputs = guidance.outputs
657+
noise_pred = guidance(**outputs)
658658
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
659659
guidance.cleanup_models(self.transformer)
660660

0 commit comments

Comments
 (0)