Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, UMT5EncoderModel

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
Expand Down Expand Up @@ -131,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):

_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
]

def __init__(
self,
Expand Down Expand Up @@ -159,12 +164,19 @@ def check_inputs(
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
Expand Down Expand Up @@ -387,6 +399,14 @@ def upcast_vae(self):
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)

@property
def guidance_scale(self):
return self._guidance_scale

@property
def num_timesteps(self):
return self._num_timesteps

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand All @@ -408,6 +428,10 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -462,6 +486,15 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

Examples:
Expand All @@ -483,8 +516,11 @@ def __call__(
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)

self._guidance_scale = guidance_scale

# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -541,6 +577,7 @@ def __call__(

# 6. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
Expand All @@ -567,6 +604,15 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
Expand Down
34 changes: 33 additions & 1 deletion src/diffusers/pipelines/lumina/pipeline_lumina.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import math
import re
import urllib.parse as ul
from typing import List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoModel, AutoTokenizer

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.embeddings import get_2d_rotary_pos_embed_lumina
Expand Down Expand Up @@ -174,6 +175,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):

_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
]

def __init__(
self,
Expand Down Expand Up @@ -395,12 +400,20 @@ def check_inputs(
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)

if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
Expand Down Expand Up @@ -644,6 +657,10 @@ def __call__(
max_sequence_length: int = 256,
scaling_watershed: Optional[float] = 1.0,
proportional_attn: Optional[bool] = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -735,7 +752,11 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)

self._guidance_scale = guidance_scale

cross_attention_kwargs = {}

# 2. Define call parameters
Expand Down Expand Up @@ -797,6 +818,8 @@ def __call__(
latents,
)

self._num_timesteps = len(timesteps)

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
Expand Down Expand Up @@ -886,6 +909,15 @@ def __call__(

progress_bar.update()

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

if XLA_AVAILABLE:
xm.mark_step()

Expand Down