src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py: 36 changes (34 additions, 2 deletions)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, UMT5EncoderModel
@@ -159,12 +159,19 @@ def check_inputs(
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -408,6 +415,8 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], Dict]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
@@ -462,6 +471,15 @@
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include the tensors listed in
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed to the callback as the `callback_kwargs` argument. Only variables listed in the
`._callback_tensor_inputs` attribute of the pipeline class can be included.
max_sequence_length (`int`, *optional*, defaults to 256): Maximum sequence length to use with the `prompt`.

Examples:
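For illustration, wiring up the new arguments end to end might look like this (checkpoint id and callback body are assumptions, not part of this diff):

    import torch
    from diffusers import AuraFlowPipeline

    pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

    def log_latents(pipe, step, timestep, callback_kwargs):
        # callback_kwargs contains only the tensors requested below
        print(f"step {step}: latent norm = {callback_kwargs['latents'].norm().item():.3f}")
        return callback_kwargs  # must return a dict; the loop pops known keys from it

    image = pipe(
        "a photo of an astronaut riding a horse",
        callback_on_step_end=log_latents,
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]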
@@ -483,6 +501,7 @@
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
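If a name outside the whitelist is requested, `check_inputs` raises before any denoising work is done. A hypothetical session (the tensor name `noise_pred` is illustrative):

    pipe(prompt, callback_on_step_end=cb, callback_on_step_end_tensor_inputs=["noise_pred"])
    # ValueError: `callback_on_step_end_tensor_inputs` has to be in ['latents', ...], but found ['noise_pred']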

# 2. Determine batch size.
@@ -567,6 +586,19 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

# update the progress bar
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
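Because the loop pops tensors back out of the callback's return value, a callback that modifies state must return a dict; returning the (possibly mutated) `callback_kwargs` is the usual pattern. A sketch of a callback that damps the latents at one step, purely as an assumed example:

    def damp_latents_once(pipe, step, timestep, callback_kwargs):
        # illustrative only: scale the latents down at a single step
        if step == 10:
            callback_kwargs["latents"] = callback_kwargs["latents"] * 0.95
        return callback_kwargs

Note that the tensors are fetched with `locals()[k]`, so every whitelisted name must exist as a local variable in `__call__` at this point in the loop.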
src/diffusers/pipelines/lumina/pipeline_lumina.py: 26 changes (25 additions, 1 deletion)
@@ -17,7 +17,7 @@
import math
import re
import urllib.parse as ul
from typing import List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoModel, AutoTokenizer
@@ -395,11 +395,19 @@ def check_inputs(
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)

if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -644,6 +652,8 @@ def __call__(
max_sequence_length: int = 256,
scaling_watershed: Optional[float] = 1.0,
proportional_attn: Optional[bool] = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], Dict]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -734,6 +744,7 @@
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
cross_attention_kwargs = {}
@@ -886,6 +897,19 @@

progress_bar.update()

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

if XLA_AVAILABLE:
xm.mark_step()
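The same callback contract applies here; for example, intermediate latents could be snapshotted for later inspection (cadence and variable names are assumptions):

    intermediate = []

    def grab_latents(pipe, step, timestep, callback_kwargs):
        # keep a CPU copy of the current latents every fifth step
        if step % 5 == 0:
            intermediate.append(callback_kwargs["latents"].detach().cpu())
        return callback_kwargs

    # pass via: pipe(..., callback_on_step_end=grab_latents,
    #                callback_on_step_end_tensor_inputs=["latents"])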
