From cf93da2b7534ce1dad3cd7141c65afe446dc417e Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Fri, 7 Feb 2025 19:02:55 +0000 Subject: [PATCH 01/12] Add support for callback_on_step_end for AuraFlowPipeline and LuminaText2ImgPipeline. --- .../pipelines/aura_flow/pipeline_aura_flow.py | 36 +++++++++++++++++-- .../pipelines/lumina/pipeline_lumina.py | 26 +++++++++++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index a3677e6a5a39..3377e1f8aacc 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Callable import torch from transformers import T5Tokenizer, UMT5EncoderModel @@ -159,12 +159,19 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, ): if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." ) - + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -408,6 +415,8 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. @@ -462,6 +471,15 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: @@ -483,6 +501,7 @@ def __call__( negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) # 2. Determine batch size. @@ -567,6 +586,19 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 133cb2c5f146..38d3802cf2f0 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -17,7 +17,7 @@ import math import re import urllib.parse as ul -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Callable import torch from transformers import AutoModel, AutoTokenizer @@ -395,11 +395,19 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, ): if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) if prompt is not None and prompt_embeds is not None: raise ValueError( @@ -644,6 +652,8 @@ def __call__( max_sequence_length: int = 256, scaling_watershed: Optional[float] = 1.0, proportional_attn: Optional[bool] = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -734,6 +744,7 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, negative_prompt_attention_mask=negative_prompt_attention_mask, ) cross_attention_kwargs = {} @@ -886,6 +897,19 @@ def __call__( progress_bar.update() + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + if XLA_AVAILABLE: xm.mark_step() From 44914811fdbfea913ff8539b3a893592682d0eaa Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Wed, 12 Feb 2025 19:20:56 +0530 Subject: [PATCH 02/12] Apply the suggestions from code review for lumina and auraflow Co-authored-by: hlky --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 4 ---- src/diffusers/pipelines/lumina/pipeline_lumina.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 3377e1f8aacc..2dc2859e89b6 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -594,10 +594,6 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 38d3802cf2f0..b6be04bcaca7 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -905,10 +905,6 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) if XLA_AVAILABLE: xm.mark_step() From 770c12cc08f45671c7528c752349f2808c7b29d0 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Wed, 12 Feb 2025 15:19:15 +0000 Subject: [PATCH 03/12] Update missing inputs and imports. --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 7 +++++++ src/diffusers/pipelines/lumina/pipeline_lumina.py | 1 + 2 files changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 2dc2859e89b6..1ea92ad3087c 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -24,6 +24,7 @@ from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback if is_torch_xla_available(): @@ -131,6 +132,12 @@ class AuraFlowPipeline(DiffusionPipeline): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] def __init__( self, diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index b6be04bcaca7..6f6e858d6e64 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -37,6 +37,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...callbacks import MultiPipelineCallbacks, PipelineCallback if is_torch_xla_available(): From 00e1209cd1f05d2145d8abc8f599eec1357037fc Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Wed, 12 Feb 2025 15:34:29 +0000 Subject: [PATCH 04/12] Add input field. --- src/diffusers/pipelines/lumina/pipeline_lumina.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 6f6e858d6e64..8f4c2b0bb817 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -175,6 +175,13 @@ class LuminaText2ImgPipeline(DiffusionPipeline): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] + def __init__( self, From 154f34d7dbf89bdf3aa1ae23ee0230f1269a3556 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Thu, 13 Feb 2025 20:31:56 +0530 Subject: [PATCH 05/12] Apply suggestions from code review-2 Co-authored-by: hlky --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 2 -- src/diffusers/pipelines/lumina/pipeline_lumina.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 1ea92ad3087c..0cb41248182f 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -135,8 +135,6 @@ class AuraFlowPipeline(DiffusionPipeline): _callback_tensor_inputs = [ "latents", "prompt_embeds", - "add_text_embeds", - "add_time_ids", ] def __init__( diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 8f4c2b0bb817..52ce0f465289 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -178,8 +178,6 @@ class LuminaText2ImgPipeline(DiffusionPipeline): _callback_tensor_inputs = [ "latents", "prompt_embeds", - "add_text_embeds", - "add_time_ids", ] From 0976a9adaccc47a0e1b667de46fded2a283223df Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Thu, 13 Feb 2025 20:39:45 +0530 Subject: [PATCH 06/12] Apply the suggestions from review for unused imports. Co-authored-by: hlky --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 4 +++- src/diffusers/pipelines/lumina/pipeline_lumina.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 0cb41248182f..807d9455e206 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -420,7 +420,9 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[ImagePipelineOutput, Tuple]: r""" diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 52ce0f465289..aa8b12adbd9e 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -658,7 +658,9 @@ def __call__( max_sequence_length: int = 256, scaling_watershed: Optional[float] = 1.0, proportional_attn: Optional[bool] = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ) -> Union[ImagePipelineOutput, Tuple]: """ From 98548077ea21c4ceac6acd098162aa180b104aa3 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Thu, 13 Feb 2025 15:18:49 +0000 Subject: [PATCH 07/12] make style. --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 5 ++--- src/diffusers/pipelines/lumina/pipeline_lumina.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 1ea92ad3087c..037177db530b 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import List, Optional, Tuple, Union, Dict, Callable +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from transformers import T5Tokenizer, UMT5EncoderModel @@ -24,7 +24,6 @@ from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...callbacks import MultiPipelineCallbacks, PipelineCallback if is_torch_xla_available(): @@ -172,7 +171,7 @@ def check_inputs( raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." ) - + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 8f4c2b0bb817..e82c38968d82 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -17,7 +17,7 @@ import math import re import urllib.parse as ul -from typing import List, Optional, Tuple, Union, Dict, Callable +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from transformers import AutoModel, AutoTokenizer @@ -37,7 +37,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...callbacks import MultiPipelineCallbacks, PipelineCallback if is_torch_xla_available(): @@ -409,7 +408,7 @@ def check_inputs( raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." ) - + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): From aedca12948b3346e9e9f46086924b1aa711c0627 Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 16 Feb 2025 16:48:08 +0000 Subject: [PATCH 08/12] Update pipeline_aura_flow.py --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 1b6877fd1db9..d16788472ee6 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -17,6 +17,7 @@ import torch from transformers import T5Tokenizer, UMT5EncoderModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor From 1931f41b79684cd407e5cfafde3642e4b82957bb Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 16 Feb 2025 16:48:44 +0000 Subject: [PATCH 09/12] Update pipeline_lumina.py --- src/diffusers/pipelines/lumina/pipeline_lumina.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 791612c2dab0..82b3361fa9f4 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -22,6 +22,7 @@ import torch from transformers import AutoModel, AutoTokenizer +from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.embeddings import get_2d_rotary_pos_embed_lumina From 0bcc6a1f6fb968e95e8e66c8337d0d8b52c18a1d Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 16 Feb 2025 16:50:25 +0000 Subject: [PATCH 10/12] Update pipeline_lumina.py --- src/diffusers/pipelines/lumina/pipeline_lumina.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 82b3361fa9f4..5a0a688001a0 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -180,7 +180,6 @@ class LuminaText2ImgPipeline(DiffusionPipeline): "prompt_embeds", ] - def __init__( self, transformer: LuminaNextDiT2DModel, From 6b7362a0efe8660c568d76208daf8bbbfa795bd7 Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 16 Feb 2025 17:11:39 +0000 Subject: [PATCH 11/12] Update pipeline_aura_flow.py --- .../pipelines/aura_flow/pipeline_aura_flow.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index d16788472ee6..ea60e66d2db9 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -399,6 +399,14 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -511,6 +519,8 @@ def __call__( callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -567,6 +577,7 @@ def __call__( # 6. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance From 3546916ff06bdcca2da2d1670b18411ef1d082f9 Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 16 Feb 2025 17:14:17 +0000 Subject: [PATCH 12/12] Update pipeline_lumina.py --- src/diffusers/pipelines/lumina/pipeline_lumina.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 5a0a688001a0..4f6793e17b37 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -751,9 +751,12 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) + + self._guidance_scale = guidance_scale + cross_attention_kwargs = {} # 2. Define call parameters @@ -815,6 +818,8 @@ def __call__( latents, ) + self._num_timesteps = len(timesteps) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps):