
Commit 2c2ec8f

Comments, a bit refactor
1 parent 79e35bd commit 2c2ec8f

2 files changed: 98 additions & 71 deletions

invokeai/backend/stable_diffusion/denoise_context.py

Lines changed: 68 additions & 59 deletions
@@ -8,7 +8,7 @@
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

 if TYPE_CHECKING:
-    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData


 @dataclass
@@ -31,92 +31,101 @@ class UNetKwargs:

 @dataclass
 class DenoiseInputs:
-    """Initial variables passed to denoise. Supposed to be unchanged.
-
-    Variables:
-        orig_latents: The latent-space image to denoise.
-            Shape: [batch, channels, latent_height, latent_width]
-            - If we are inpainting, this is the initial latent image before noise has been added.
-            - If we are generating a new image, this should be initialized to zeros.
-            - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
-        scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
-        conditioning_data: Text conditionging data.
-        noise: Noise used for two purposes:
-            Shape: [1 or batch, channels, latent_height, latent_width]
-            1. Used by the scheduler to noise the initial `latents` before denoising.
-            2. Used to noise the `masked_latents` when inpainting.
-            `noise` should be None if the `latents` tensor has already been noised.
-        seed: The seed used to generate the noise for the denoising process.
-            HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
-            same noise used earlier in the pipeline. This should really be handled in a clearer way.
-        timesteps: The timestep schedule for the denoising process.
-        init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
-            should be populated if you want noise applied *even* if timesteps is empty.
-        attention_processor_cls: Class of attention processor that is used.
-    """
+    """Initial variables passed to denoise. Supposed to be unchanged."""

+    # The latent-space image to denoise.
+    # Shape: [batch, channels, latent_height, latent_width]
+    # - If we are inpainting, this is the initial latent image before noise has been added.
+    # - If we are generating a new image, this should be initialized to zeros.
+    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
     orig_latents: torch.Tensor
+
+    # kwargs forwarded to the scheduler.step() method.
     scheduler_step_kwargs: dict[str, Any]
+
+    # Text conditioning data.
     conditioning_data: TextConditioningData
+
+    # Noise used for two purposes:
+    # 1. Used by the scheduler to noise the initial `latents` before denoising.
+    # 2. Used to noise the `masked_latents` when inpainting.
+    # `noise` should be None if the `latents` tensor has already been noised.
+    # Shape: [1 or batch, channels, latent_height, latent_width]
     noise: Optional[torch.Tensor]
+
+    # The seed used to generate the noise for the denoising process.
+    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
+    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
     seed: int
+
+    # The timestep schedule for the denoising process.
     timesteps: torch.Tensor
+
+    # The first timestep in the schedule. This is used to determine the initial noise level, so
+    # should be populated if you want noise applied *even* if timesteps is empty.
     init_timestep: torch.Tensor
+
+    # Class of attention processor that is used.
     attention_processor_cls: Type[Any]


 @dataclass
 class DenoiseContext:
-    """Context with all variables in denoise
-
-    Variables:
-        inputs: Initial variables passed to denoise. Supposed to be unchanged.
-        scheduler: Scheduler which used to apply noise predictions.
-        unet: UNet model.
-        latents: Current state of latent-space image in denoising process.
-            None until `pre_denoise_loop` callback.
-            Shape: [batch, channels, latent_height, latent_width]
-        step_index: Current denoising step index.
-            None until `pre_step` callback.
-        timestep: Current denoising step timestep.
-            None until `pre_step` callback.
-        unet_kwargs: Arguments which will be passed to U Net model.
-            Available in `pre_unet`/`post_unet` callbacks, otherwice will be None.
-        step_output: SchedulerOutput class returned from step function(normally, generated by scheduler).
-            Supposed to be used only in `post_step` callback, otherwice can be None.
-        latent_model_input: Scaled version of `latents`, which will be passed to unet_kwargs initialization.
-            Available in events inside step(between `pre_step` and `post_stop`).
-            Shape: [batch, channels, latent_height, latent_width]
-        conditioning_mode: [TMP] Defines on which conditionings current unet call will be runned.
-            Available in `pre_unet`/`post_unet` callbacks, otherwice will be None.
-            Can be "negative", "positive" or "both"
-        negative_noise_pred: [TMP] Noise predictions from negative conditioning.
-            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
-            Shape: [batch, channels, latent_height, latent_width]
-        positive_noise_pred: [TMP] Noise predictions from positive conditioning.
-            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
-            Shape: [batch, channels, latent_height, latent_width]
-        noise_pred: Combined noise prediction from passed conditionings.
-            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
-            Shape: [batch, channels, latent_height, latent_width]
-        extra: Dictionary for extensions to pass extra info about denoise process to other extensions.
-    """
+    """Context with all variables in denoise."""

+    # Initial variables passed to denoise. Supposed to be unchanged.
     inputs: DenoiseInputs

+    # Scheduler which is used to apply noise predictions.
     scheduler: SchedulerMixin
+
+    # UNet model.
     unet: Optional[UNet2DConditionModel] = None

+    # Current state of latent-space image in denoising process.
+    # None until `pre_denoise_loop` callback.
+    # Shape: [batch, channels, latent_height, latent_width]
     latents: Optional[torch.Tensor] = None
+
+    # Current denoising step index.
+    # None until `pre_step` callback.
     step_index: Optional[int] = None
+
+    # Current denoising step timestep.
+    # None until `pre_step` callback.
     timestep: Optional[torch.Tensor] = None
+
+    # Arguments which will be passed to the UNet model.
+    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
     unet_kwargs: Optional[UNetKwargs] = None
+
+    # SchedulerOutput class returned from the step function (normally generated by the scheduler).
+    # Supposed to be used only in `post_step` callback, otherwise can be None.
     step_output: Optional[SchedulerOutput] = None

+    # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
+    # Available in events inside step (between `pre_step` and `post_step`).
+    # Shape: [batch, channels, latent_height, latent_width]
     latent_model_input: Optional[torch.Tensor] = None
-    conditioning_mode: Optional[str] = None
+
+    # [TMP] Defines which conditionings the current unet call will be run on.
+    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
+    conditioning_mode: Optional[ConditioningMode] = None
+
+    # [TMP] Noise predictions from negative conditioning.
+    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Shape: [batch, channels, latent_height, latent_width]
     negative_noise_pred: Optional[torch.Tensor] = None
+
+    # [TMP] Noise predictions from positive conditioning.
+    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Shape: [batch, channels, latent_height, latent_width]
     positive_noise_pred: Optional[torch.Tensor] = None
+
+    # Combined noise prediction from passed conditionings.
+    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
+    # Shape: [batch, channels, latent_height, latent_width]
     noise_pred: Optional[torch.Tensor] = None

+    # Dictionary for extensions to pass extra info about the denoise process to other extensions.
     extra: dict = field(default_factory=dict)
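
For readers skimming the diff, here is a minimal sketch (not part of this commit) of how a denoise-loop callback might read the now-commented DenoiseContext fields. The log_step function and the way it would be registered are assumptions for illustration; only the DenoiseContext fields and the import path come from the file above.

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


def log_step(ctx: DenoiseContext) -> None:
    # Hypothetical callback body: step_index/timestep are None until `pre_step` fires,
    # and latents is None until `pre_denoise_loop` has run.
    if ctx.step_index is None or ctx.timestep is None or ctx.latents is None:
        return
    # latents has shape [batch, channels, latent_height, latent_width].
    print(f"step={ctx.step_index} t={int(ctx.timestep)} latents={tuple(ctx.latents.shape)}")
    # conditioning_mode is only populated around the UNet call (`pre_unet`/`post_unet`).
    if ctx.conditioning_mode is not None:
        print(f"conditioning_mode={ctx.conditioning_mode}")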

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

Lines changed: 30 additions & 12 deletions
@@ -137,6 +137,12 @@ def is_sdxl(self):
         return isinstance(self.cond_text, SDXLConditioningInfo)

     def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
+        """Fills UNet arguments with data from the provided conditionings.
+
+        Args:
+            unet_kwargs (UNetKwargs): Object which stores the UNet model arguments.
+            conditioning_mode (ConditioningMode): Describes which conditionings should be used.
+        """
         _, _, h, w = unet_kwargs.sample.shape
         device = unet_kwargs.sample.device
         dtype = unet_kwargs.sample.dtype
@@ -187,16 +193,21 @@ def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
             )

     @staticmethod
-    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
         return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

     @classmethod
     def _pad_conditioning(
         cls,
         cond: torch.Tensor,
         target_len: int,
-        encoder_attention_mask: Optional[torch.Tensor],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Pads the provided conditioning tensor to target_len with zeros and returns a mask of the unpadded tokens.
+
+        Args:
+            cond (torch.Tensor): Conditioning tensor to pad with zeros.
+            target_len (int): Target length (token count) to pad the tensor to.
+        """
         conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)

         if cond.shape[1] < target_len:
@@ -212,21 +223,28 @@ def _pad_conditioning(
                 dim=1,
             )

-        if encoder_attention_mask is None:
-            encoder_attention_mask = conditioning_attention_mask
-        else:
-            encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
-
-        return cond, encoder_attention_mask
+        return cond, conditioning_attention_mask

     @classmethod
-    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
+    def _concat_conditionings_for_batch(
+        cls,
+        conditionings: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Concatenates the provided conditioning tensors into one batched tensor.
+        If the tensors have different lengths, they are padded with zeros and an
+        encoder_attention_mask is created to exclude the padding from attention.
+
+        Args:
+            conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
+        """
         encoder_attention_mask = None
         max_len = max([c.shape[1] for c in conditionings])
         if any(c.shape[1] != max_len for c in conditionings):
+            encoder_attention_masks = [None] * len(conditionings)
             for i in range(len(conditionings)):
-                conditionings[i], encoder_attention_mask = cls._pad_conditioning(
-                    conditionings[i], max_len, encoder_attention_mask
+                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
+                    conditionings[i], max_len
                 )
+            encoder_attention_mask = torch.cat(encoder_attention_masks)

         return torch.cat(conditionings), encoder_attention_mask
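
To make the new masking behaviour concrete, the sketch below reimplements the same idea with plain torch (it deliberately does not call the private helpers above, and the name concat_with_mask is made up for illustration): each conditioning is zero-padded to the longest token length and paired with its own attention mask, and the per-conditioning masks are concatenated alongside the conditionings, which is what _concat_conditionings_for_batch now does instead of threading a single mask through _pad_conditioning.

from typing import List, Optional, Tuple

import torch


def concat_with_mask(conditionings: List[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Zero-pad conditionings to a common token length and build a matching attention mask."""
    max_len = max(c.shape[1] for c in conditionings)
    if all(c.shape[1] == max_len for c in conditionings):
        # Same token length everywhere: no padding and no mask needed.
        return torch.cat(conditionings), None

    padded, masks = [], []
    for c in conditionings:
        pad = max_len - c.shape[1]
        # 1.0 marks real tokens, 0.0 marks padded positions.
        mask = torch.ones((c.shape[0], c.shape[1]), device=c.device, dtype=c.dtype)
        if pad > 0:
            c = torch.cat([c, torch.zeros((c.shape[0], pad, c.shape[2]), device=c.device, dtype=c.dtype)], dim=1)
            mask = torch.cat([mask, torch.zeros((mask.shape[0], pad), device=mask.device, dtype=mask.dtype)], dim=1)
        padded.append(c)
        masks.append(mask)
    # Batch along dim 0; one mask per conditioning, concatenated in the same order.
    return torch.cat(padded), torch.cat(masks)


# Example: two prompts of 77 and 154 tokens with embedding dim 768.
cond_a, cond_b = torch.randn(1, 77, 768), torch.randn(1, 154, 768)
cond, mask = concat_with_mask([cond_a, cond_b])
print(cond.shape, mask.shape)  # torch.Size([2, 154, 768]) torch.Size([2, 154])

The mask has one row per batch entry, so padded positions can be excluded from cross-attention via encoder_attention_mask.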

0 commit comments
