diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index a3677e6a5a39..ea60e66d2db9 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from transformers import T5Tokenizer, UMT5EncoderModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor @@ -131,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] def __init__( self, @@ -159,12 +164,19 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, ): if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -387,6 +399,14 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -408,6 +428,10 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. @@ -462,6 +486,15 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: @@ -483,8 +516,11 @@ def __call__( negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -541,6 +577,7 @@ def __call__( # 6. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance @@ -567,6 +604,15 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 133cb2c5f146..4f6793e17b37 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -17,11 +17,12 @@ import math import re import urllib.parse as ul -from typing import List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from transformers import AutoModel, AutoTokenizer +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.embeddings import get_2d_rotary_pos_embed_lumina @@ -174,6 +175,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] def __init__( self, @@ -395,12 +400,20 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, ): if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -644,6 +657,10 @@ def __call__( max_sequence_length: int = 256, scaling_watershed: Optional[float] = 1.0, proportional_attn: Optional[bool] = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -735,7 +752,11 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) + + self._guidance_scale = guidance_scale + cross_attention_kwargs = {} # 2. Define call parameters @@ -797,6 +818,8 @@ def __call__( latents, ) + self._num_timesteps = len(timesteps) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -886,6 +909,15 @@ def __call__( progress_bar.update() + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + if XLA_AVAILABLE: xm.mark_step()