Skip to content

Commit c8a7617

Browse files
committed
update
1 parent ce642e9 commit c8a7617

21 files changed

+924
-751
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,8 +761,8 @@
761761
LayerSkipConfig,
762762
PyramidAttentionBroadcastConfig,
763763
SmoothedEnergyGuidanceConfig,
764-
apply_layer_skip,
765764
apply_faster_cache,
765+
apply_layer_skip,
766766
apply_pyramid_attention_broadcast,
767767
)
768768
from .models import (
@@ -1085,6 +1085,7 @@
10851085
StableDiffusionSAGPipeline,
10861086
StableDiffusionUpscalePipeline,
10871087
StableDiffusionXLAdapterPipeline,
1088+
StableDiffusionXLAutoPipeline,
10881089
StableDiffusionXLControlNetImg2ImgPipeline,
10891090
StableDiffusionXLControlNetInpaintPipeline,
10901091
StableDiffusionXLControlNetPAGImg2ImgPipeline,
@@ -1102,7 +1103,6 @@
11021103
StableDiffusionXLPAGInpaintPipeline,
11031104
StableDiffusionXLPAGPipeline,
11041105
StableDiffusionXLPipeline,
1105-
StableDiffusionXLAutoPipeline,
11061106
StableUnCLIPImg2ImgPipeline,
11071107
StableUnCLIPPipeline,
11081108
StableVideoDiffusionPipeline,

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import TYPE_CHECKING, List, Optional
1717

1818
import torch
1919

2020
from .guider_utils import BaseGuidance, rescale_noise_cfg
2121

22+
2223
if TYPE_CHECKING:
2324
from ..pipelines.modular_pipeline import BlockState
2425

@@ -119,19 +120,19 @@ def num_conditions(self) -> int:
119120
def _is_apg_enabled(self) -> bool:
120121
if not self._enabled:
121122
return False
122-
123+
123124
is_within_range = True
124125
if self._num_inference_steps is not None:
125126
skip_start_step = int(self._start * self._num_inference_steps)
126127
skip_stop_step = int(self._stop * self._num_inference_steps)
127128
is_within_range = skip_start_step <= self._step < skip_stop_step
128-
129+
129130
is_close = False
130131
if self.use_original_formulation:
131132
is_close = math.isclose(self.guidance_scale, 0.0)
132133
else:
133134
is_close = math.isclose(self.guidance_scale, 1.0)
134-
135+
135136
return is_within_range and not is_close
136137

137138

@@ -156,25 +157,25 @@ def normalized_guidance(
156157
):
157158
diff = pred_cond - pred_uncond
158159
dim = [-i for i in range(1, len(diff.shape))]
159-
160+
160161
if momentum_buffer is not None:
161162
momentum_buffer.update(diff)
162163
diff = momentum_buffer.running_average
163-
164+
164165
if norm_threshold > 0:
165166
ones = torch.ones_like(diff)
166167
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
167168
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
168169
diff = diff * scale_factor
169-
170+
170171
v0, v1 = diff.double(), pred_cond.double()
171172
v1 = torch.nn.functional.normalize(v1, dim=dim)
172173
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
173174
v0_orthogonal = v0 - v0_parallel
174175
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
175176
normalized_update = diff_orthogonal + eta * diff_parallel
176-
177+
177178
pred = pred_cond if use_original_formulation else pred_uncond
178179
pred = pred + guidance_scale * normalized_update
179-
180+
180181
return pred

src/diffusers/guiders/auto_guidance.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import List, Optional, Union, TYPE_CHECKING
16+
from typing import TYPE_CHECKING, List, Optional, Union
1717

1818
import torch
1919

2020
from ..hooks import HookRegistry, LayerSkipConfig
2121
from ..hooks.layer_skip import _apply_layer_skip_hook
2222
from .guider_utils import BaseGuidance, rescale_noise_cfg
2323

24+
2425
if TYPE_CHECKING:
2526
from ..pipelines.modular_pipeline import BlockState
2627

@@ -113,13 +114,13 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None:
113114
if self._is_ag_enabled() and self.is_unconditional:
114115
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
115116
_apply_layer_skip_hook(denoiser, config, name=name)
116-
117+
117118
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
118119
if self._is_ag_enabled() and self.is_unconditional:
119120
for name in self._auto_guidance_hook_names:
120121
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
121122
registry.remove_hook(name, recurse=True)
122-
123+
123124
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
124125
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
125126
data_batches = []
@@ -140,9 +141,9 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
140141

141142
if self.guidance_rescale > 0.0:
142143
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
143-
144+
144145
return pred, {}
145-
146+
146147
@property
147148
def is_conditional(self) -> bool:
148149
return self._count_prepared == 1
@@ -157,17 +158,17 @@ def num_conditions(self) -> int:
157158
def _is_ag_enabled(self) -> bool:
158159
if not self._enabled:
159160
return False
160-
161+
161162
is_within_range = True
162163
if self._num_inference_steps is not None:
163164
skip_start_step = int(self._start * self._num_inference_steps)
164165
skip_stop_step = int(self._stop * self._num_inference_steps)
165166
is_within_range = skip_start_step <= self._step < skip_stop_step
166-
167+
167168
is_close = False
168169
if self.use_original_formulation:
169170
is_close = math.isclose(self.guidance_scale, 0.0)
170171
else:
171172
is_close = math.isclose(self.guidance_scale, 1.0)
172-
173+
173174
return is_within_range and not is_close

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import TYPE_CHECKING, List, Optional
1717

1818
import torch
1919

2020
from .guider_utils import BaseGuidance, rescale_noise_cfg
2121

22+
2223
if TYPE_CHECKING:
2324
from ..pipelines.modular_pipeline import BlockState
2425

@@ -74,7 +75,7 @@ def __init__(
7475
self.guidance_scale = guidance_scale
7576
self.guidance_rescale = guidance_rescale
7677
self.use_original_formulation = use_original_formulation
77-
78+
7879
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
7980
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8081
data_batches = []
@@ -112,17 +113,17 @@ def num_conditions(self) -> int:
112113
def _is_cfg_enabled(self) -> bool:
113114
if not self._enabled:
114115
return False
115-
116+
116117
is_within_range = True
117118
if self._num_inference_steps is not None:
118119
skip_start_step = int(self._start * self._num_inference_steps)
119120
skip_stop_step = int(self._stop * self._num_inference_steps)
120121
is_within_range = skip_start_step <= self._step < skip_stop_step
121-
122+
122123
is_close = False
123124
if self.use_original_formulation:
124125
is_close = math.isclose(self.guidance_scale, 0.0)
125126
else:
126127
is_close = math.isclose(self.guidance_scale, 1.0)
127-
128+
128129
return is_within_range and not is_close

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING
16+
from typing import TYPE_CHECKING, List, Optional
1717

1818
import torch
1919

2020
from .guider_utils import BaseGuidance, rescale_noise_cfg
2121

22+
2223
if TYPE_CHECKING:
2324
from ..pipelines.modular_pipeline import BlockState
2425

@@ -72,7 +73,7 @@ def __init__(
7273
self.zero_init_steps = zero_init_steps
7374
self.guidance_rescale = guidance_rescale
7475
self.use_original_formulation = use_original_formulation
75-
76+
7677
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
7778
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
7879
data_batches = []
@@ -102,7 +103,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
102103
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
103104

104105
return pred, {}
105-
106+
106107
@property
107108
def is_conditional(self) -> bool:
108109
return self._count_prepared == 1
@@ -117,19 +118,19 @@ def num_conditions(self) -> int:
117118
def _is_cfg_enabled(self) -> bool:
118119
if not self._enabled:
119120
return False
120-
121+
121122
is_within_range = True
122123
if self._num_inference_steps is not None:
123124
skip_start_step = int(self._start * self._num_inference_steps)
124125
skip_stop_step = int(self._stop * self._num_inference_steps)
125126
is_within_range = skip_start_step <= self._step < skip_stop_step
126-
127+
127128
is_close = False
128129
if self.use_original_formulation:
129130
is_close = math.isclose(self.guidance_scale, 0.0)
130131
else:
131132
is_close = math.isclose(self.guidance_scale, 1.0)
132-
133+
133134
return is_within_range and not is_close
134135

135136

src/diffusers/guiders/guider_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
5858

5959
def disable(self):
6060
self._enabled = False
61-
61+
6262
def enable(self):
6363
self._enabled = True
64-
64+
6565
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
6666
self._step = step
6767
self._num_inference_steps = num_inference_steps
@@ -104,22 +104,22 @@ def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) ->
104104
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
105105
)
106106
self._input_fields = kwargs
107-
107+
108108
def prepare_models(self, denoiser: torch.nn.Module) -> None:
109109
"""
110110
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
111111
subclasses to implement specific model preparation logic.
112112
"""
113113
self._count_prepared += 1
114-
114+
115115
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
116116
"""
117117
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
118118
subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
119119
modifications made during `prepare_models`.
120120
"""
121121
pass
122-
122+
123123
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
124124
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
125125

@@ -139,15 +139,15 @@ def forward(self, *args, **kwargs) -> Any:
139139
@property
140140
def is_conditional(self) -> bool:
141141
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
142-
142+
143143
@property
144144
def is_unconditional(self) -> bool:
145145
return not self.is_conditional
146-
146+
147147
@property
148148
def num_conditions(self) -> int:
149149
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
150-
150+
151151
@classmethod
152152
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
153153
"""

0 commit comments

Comments
 (0)