Commit e8768e5

apply review suggestions
1 parent 78fca12 commit e8768e5

4 files changed (+128, -67 lines)

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 12 additions & 11 deletions
@@ -13,12 +13,15 @@
 # limitations under the License.
 
 import math
-from typing import Optional, Union, Tuple, List
+from typing import Optional, Union, Tuple, List, TYPE_CHECKING
 
 import torch
 
 from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
 
+if TYPE_CHECKING:
+    from ..pipelines.modular_pipeline import BlockState
+
 
 class ClassifierFreeGuidance(BaseGuidance):
     """

@@ -72,15 +75,13 @@ def __init__(
         self.guidance_rescale = guidance_rescale
         self.use_original_formulation = use_original_formulation
 
-    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
-        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
-
-    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
-        self._num_outputs_prepared += 1
-        if self._num_outputs_prepared > self.num_conditions:
-            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
-        key = self._input_predictions[self._num_outputs_prepared - 1]
-        self._preds[key] = pred
+    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
 
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
         pred = None

@@ -95,7 +96,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
         if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
 
-        return pred
+        return pred, {}
 
     @property
     def is_conditional(self) -> bool:
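
The reworked `prepare_inputs` no longer unpacks tuples of tensors argument by argument; it fans a single `BlockState` out into one batch per condition, each tagged with the identifier of the prediction it will produce (`pred_cond`, then `pred_uncond`). A minimal runnable sketch of that data flow, using `types.SimpleNamespace` as a hypothetical stand-in for `BlockState` (the helper below is illustrative, not a diffusers API):

```
from types import SimpleNamespace

IDENTIFIER_KEY = "__guidance_identifier__"  # mirrors BaseGuidance._identifier_key

def prepare_inputs_sketch(data, input_fields, input_predictions, num_conditions):
    """Build one batch per condition, mirroring the new CFG prepare_inputs."""
    tuple_indices = [0] if num_conditions == 1 else [0, 1]
    batches = []
    for i in range(num_conditions):
        batch = {}
        for key, value in input_fields.items():
            # A plain string: the same field feeds every batch.
            # A (cond, uncond) tuple: pick the field by the batch's tuple index.
            field = value if isinstance(value, str) else value[tuple_indices[i]]
            batch[key] = getattr(data, field)
        batch[IDENTIFIER_KEY] = input_predictions[i]
        batches.append(SimpleNamespace(**batch))
    return batches

data = SimpleNamespace(latents="lat", prompt_embeds="cond", negative_prompt_embeds="uncond")
batches = prepare_inputs_sketch(
    data,
    input_fields={"latents": "latents", "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds")},
    input_predictions=["pred_cond", "pred_uncond"],
    num_conditions=2,
)
print([getattr(b, IDENTIFIER_KEY) for b in batches])  # ['pred_cond', 'pred_uncond']
print([b.prompt_embeds for b in batches])             # ['cond', 'uncond']
```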

src/diffusers/guiders/entropy_rectifying_guidance.py

Whitespace-only changes.

src/diffusers/guiders/guider_utils.py

Lines changed: 90 additions & 14 deletions
@@ -21,6 +21,7 @@
 
 if TYPE_CHECKING:
     from ..models.attention_processor import AttentionProcessor
+    from ..pipelines.modular_pipeline import BlockState
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name

@@ -30,14 +31,15 @@ class BaseGuidance:
     r"""Base class providing the skeleton for implementing guidance techniques."""
 
     _input_predictions = None
+    _identifier_key = "__guidance_identifier__"
 
     def __init__(self, start: float = 0.0, stop: float = 1.0):
         self._start = start
         self._stop = stop
         self._step: int = None
         self._num_inference_steps: int = None
         self._timestep: torch.LongTensor = None
-        self._preds: Dict[str, torch.Tensor] = {}
+        self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
         self._num_outputs_prepared: int = 0
         self._enabled = True

@@ -65,28 +67,64 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen
         self._step = step
         self._num_inference_steps = num_inference_steps
         self._timestep = timestep
-        self._preds = {}
         self._num_outputs_prepared = 0
 
+    def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
+        """
+        Set the input fields for the guidance technique. The input fields are used to specify the names of the
+        returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is
+        obtained from the values of the provided keyword arguments to this method.
+
+        Args:
+            **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
+                A dictionary where the keys are the names of the fields that will be used to store the data once
+                it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
+                which is used to look up the required data provided for preparation.
+
+                If a string is provided, it will be used as the conditional data (or unconditional if used with
+                a guidance method that requires it). If a tuple of length 2 is provided, the first element must
+                be the conditional data identifier and the second element must be the unconditional data identifier
+                or None.
+
+        Example:
+
+        ```
+        data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
+
+        BaseGuidance.set_input_fields(
+            latents="latents",
+            prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
+        )
+        ```
+        """
+        for key, value in kwargs.items():
+            is_string = isinstance(value, str)
+            is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
+            if not (is_string or is_tuple_of_str_with_len_2):
+                raise ValueError(
+                    f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
+                )
+        self._input_fields = kwargs
+
     def prepare_models(self, denoiser: torch.nn.Module) -> None:
         """
         Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
         subclasses to implement specific model preparation logic.
         """
         pass
 
-    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
         raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
 
-    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
-        raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")
-
-    def __call__(self, **kwargs) -> Any:
-        if len(kwargs) != self.num_conditions:
+    def __call__(self, data: List["BlockState"]) -> Any:
+        if not all(hasattr(d, "noise_pred") for d in data):
+            raise ValueError("Expected all data to have `noise_pred` attribute.")
+        if len(data) != self.num_conditions:
             raise ValueError(
-                f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
+                f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
             )
-        return self.forward(**kwargs)
+        forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
+        return self.forward(**forward_inputs)
 
     def forward(self, *args, **kwargs) -> Any:
         raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

@@ -102,10 +140,48 @@ def is_unconditional(self) -> bool:
 
     @property
     def num_conditions(self) -> int:
         raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
-
-    @property
-    def outputs(self) -> Dict[str, torch.Tensor]:
-        return self._preds, {}
+
+    @classmethod
+    def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
+        """
+        Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of
+        the `BaseGuidance` class. It prepares the batch based on the provided tuple index.
+
+        Args:
+            input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
+                A dictionary where the keys are the names of the fields that will be used to store the data once
+                it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
+                which is used to look up the required data provided for preparation.
+                If a string is provided, it will be used as the conditional data (or unconditional if used with
+                a guidance method that requires it). If a tuple of length 2 is provided, the first element must
+                be the conditional data identifier and the second element must be the unconditional data identifier
+                or None.
+            data (`BlockState`):
+                The input data to be prepared.
+            tuple_index (`int`):
+                The index to use when accessing input fields that are tuples.
+
+        Returns:
+            `BlockState`: The prepared batch of data.
+        """
+        from ..pipelines.modular_pipeline import BlockState
+
+        if input_fields is None:
+            raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
+        data_batch = {}
+        for key, value in input_fields.items():
+            try:
+                if isinstance(value, str):
+                    data_batch[key] = getattr(data, value)
+                elif isinstance(value, tuple):
+                    data_batch[key] = getattr(data, value[tuple_index])
+                else:
+                    # We've already checked that value is a string or a tuple of strings with length 2
+                    pass
+            except AttributeError:
+                raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
+        data_batch[cls._identifier_key] = identifier
+        return BlockState(**data_batch)
 
 
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
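
The `__call__` rewrite replaces the old `prepare_outputs`/`outputs` bookkeeping: each batch produced by `_prepare_batch` carries the `_identifier_key` attribute, the denoiser writes its `noise_pred` onto the batch, and `__call__` turns those identifiers into the keyword arguments of `forward`. A runnable sketch of just that routing, again with `SimpleNamespace` standing in for `BlockState` and strings in place of tensors (illustrative names, not diffusers APIs):

```
from types import SimpleNamespace

IDENTIFIER_KEY = "__guidance_identifier__"  # mirrors BaseGuidance._identifier_key

def call_sketch(data, num_conditions, forward):
    # Every prepared batch must have been filled in with a `noise_pred`.
    if not all(hasattr(d, "noise_pred") for d in data):
        raise ValueError("Expected all data to have `noise_pred` attribute.")
    if len(data) != num_conditions:
        raise ValueError(f"Expected {num_conditions} data items, but got {len(data)}.")
    # The identifier stored by _prepare_batch becomes the keyword name, so
    # forward(pred_cond=..., pred_uncond=...) receives the matching predictions.
    forward_inputs = {getattr(d, IDENTIFIER_KEY): d.noise_pred for d in data}
    return forward(**forward_inputs)

def fake_cfg_forward(pred_cond=None, pred_uncond=None):
    return f"combine({pred_cond}, {pred_uncond})", {}

batches = [
    SimpleNamespace(noise_pred="eps_cond", **{IDENTIFIER_KEY: "pred_cond"}),
    SimpleNamespace(noise_pred="eps_uncond", **{IDENTIFIER_KEY: "pred_uncond"}),
]
print(call_sketch(batches, 2, fake_cfg_forward))
# ('combine(eps_cond, eps_uncond)', {})
```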

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 26 additions & 42 deletions
@@ -2239,64 +2239,48 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
         data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)
 
+        pipeline.guider.set_input_fields(
+            prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
+            add_time_ids=("add_time_ids", "negative_add_time_ids"),
+            pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+            ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+        )
+
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
                 pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                guider_data = pipeline.guider.prepare_inputs(data)
 
-                (
-                    latents,
-                    prompt_embeds,
-                    add_time_ids,
-                    pooled_prompt_embeds,
-                    mask,
-                    masked_image_latents,
-                    ip_adapter_embeds,
-                ) = pipeline.guider.prepare_inputs(
-                    pipeline.unet,
-                    data.latents,
-                    (data.prompt_embeds, data.negative_prompt_embeds),
-                    (data.add_time_ids, data.negative_add_time_ids),
-                    (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds),
-                    data.mask,
-                    data.masked_image_latents,
-                    (data.ip_adapter_embeds, data.negative_ip_adapter_embeds),
-                )
-
-                for batch_index, (
-                    latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i,
-                ) in enumerate(zip(
-                    latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds
-                )):
+                data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t)
+
+                # Prepare for inpainting
+                if data.num_channels_unet == 9:
+                    data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1)
+
+                for batch in guider_data:
                     pipeline.guider.prepare_models(pipeline.unet)
-                    latents_i = pipeline.scheduler.scale_model_input(latents_i, t)
-
-                    # Prepare for inpainting
-                    if data.num_channels_unet == 9:
-                        latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1)
 
                     # Prepare additional conditionings
-                    data.added_cond_kwargs = {
-                        "text_embeds": pooled_prompt_embeds_i,
-                        "time_ids": add_time_ids_i,
+                    batch.added_cond_kwargs = {
+                        "text_embeds": batch.pooled_prompt_embeds,
+                        "time_ids": batch.add_time_ids,
                     }
-                    if ip_adapter_embeds_i is not None:
-                        data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i
-
+                    if batch.ip_adapter_embeds is not None:
+                        batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds
+
                     # Predict the noise residual
-                    data.noise_pred = pipeline.unet(
-                        latents_i,
+                    batch.noise_pred = pipeline.unet(
+                        data.scaled_latents,
                         t,
-                        encoder_hidden_states=prompt_embeds_i,
+                        encoder_hidden_states=batch.prompt_embeds,
                         timestep_cond=data.timestep_cond,
                         cross_attention_kwargs=data.cross_attention_kwargs,
-                        added_cond_kwargs=data.added_cond_kwargs,
+                        added_cond_kwargs=batch.added_cond_kwargs,
                         return_dict=False,
                     )[0]
-                    data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
 
                 # Perform guidance
-                outputs, scheduler_step_kwargs = pipeline.guider.outputs
-                data.noise_pred = pipeline.guider(**outputs)
+                data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data)
 
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype