
Commit cf93da2

Add support for callback_on_step_end for AuraFlowPipeline and LuminaText2ImgPipeline.

1 parent 464374f

File tree

2 files changed: +59 -3 lines changed
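
Usage, as a minimal sketch of the API this commit adds (the checkpoint id `fal/AuraFlow` and the prompt are illustrative; note that the denoising loop pops tensors from the callback's return value, so the callback should return its `callback_kwargs` dict):

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

def log_latents(pipeline, step, timestep, callback_kwargs):
    # `callback_kwargs` holds the tensors requested via
    # `callback_on_step_end_tensor_inputs` ("latents" by default).
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={timestep}): latents norm = {latents.norm().item():.2f}")
    return callback_kwargs  # the loop pops (possibly modified) tensors from this dict

image = pipe(
    "a photo of an astronaut riding a horse",
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```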

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
Lines changed: 34 additions & 2 deletions
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Callable
 
 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel
```
```diff
@@ -159,12 +159,19 @@ def check_inputs(
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
             )
-
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
```
```diff
@@ -408,6 +415,8 @@ def __call__(
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.
```
```diff
@@ -462,6 +471,15 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is
+                called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
+                timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
+                specified by `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the
+                list will be passed as the `callback_kwargs` argument. You will only be able to include variables
+                listed in the `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
 
         Examples:
```
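
A sketch of a callback matching the documented signature. Although the annotation says the callable returns `None`, the loop below pops tensors from the return value, so returning the dict is what actually works:

```python
def dampen_late_steps(pipeline, step, timestep, callback_kwargs):
    # Toy intervention: lightly dampen the latents near the end of sampling
    # (the step threshold assumes roughly 50 inference steps).
    if step >= 45:
        callback_kwargs["latents"] = callback_kwargs["latents"] * 0.99
    return callback_kwargs
```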
```diff
@@ -483,6 +501,7 @@ def __call__(
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
         # 2. Determine batch size.
```
```diff
@@ -567,6 +586,19 @@ def __call__(
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
```
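
The gather/write-back mechanics of the block above, isolated as a self-contained sketch: requested loop variables are looked up by name via `locals()`, and `dict.pop` with a default keeps the old value for anything the callback left untouched. Names and tensor shapes here are toys:

```python
import torch

def fake_step(callback, i, t, latents, prompt_embeds):
    callback_kwargs = {}
    for k in ["latents"]:  # stands in for callback_on_step_end_tensor_inputs
        callback_kwargs[k] = locals()[k]  # gather loop variables by name
    callback_outputs = callback(None, i, t, callback_kwargs)
    # write back, defaulting to the existing value when the callback did not return one
    latents = callback_outputs.pop("latents", latents)
    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
    return latents, prompt_embeds

lat, emb = fake_step(
    lambda p, i, t, kw: {"latents": kw["latents"] * 0},  # callback zeroes the latents
    0, 999, torch.ones(1, 4, 8, 8), torch.ones(1, 16, 32),
)
print(lat.abs().sum().item())  # 0.0: the returned tensor replaced the local
```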

src/diffusers/pipelines/lumina/pipeline_lumina.py
Lines changed: 25 additions & 1 deletion
```diff
@@ -17,7 +17,7 @@
 import math
 import re
 import urllib.parse as ul
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Callable
 
 import torch
 from transformers import AutoModel, AutoTokenizer
```
```diff
@@ -395,11 +395,19 @@ def check_inputs(
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
             )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
```
```diff
@@ -644,6 +652,8 @@ def __call__(
         max_sequence_length: int = 256,
         scaling_watershed: Optional[float] = 1.0,
         proportional_attn: Optional[bool] = True,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
```
```diff
@@ -734,6 +744,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         cross_attention_kwargs = {}
```
```diff
@@ -886,6 +897,19 @@ def __call__(
 
                 progress_bar.update()
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+
                 if XLA_AVAILABLE:
                     xm.mark_step()
```

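And the same API on the Lumina side, as a minimal sketch (the checkpoint id `Alpha-VLLM/Lumina-Next-SFT-diffusers` and the prompt are illustrative):

```python
import torch
from diffusers import LuminaText2ImgPipeline

pipe = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

intermediate_latents = []

def capture(pipeline, step, timestep, callback_kwargs):
    # Stash a CPU copy of the latents at every step for later inspection.
    intermediate_latents.append(callback_kwargs["latents"].detach().cpu())
    return callback_kwargs

image = pipe(
    "a serene mountain lake at dawn",
    callback_on_step_end=capture,  # tensor inputs default to ["latents"]
).images[0]
print(f"captured {len(intermediate_latents)} intermediate latents")
```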