21 changes: 21 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline.py
@@ -479,6 +479,22 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->

return list(combined_dict.values())

@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs]

@property
def intermediate_input_names(self) -> List[str]:
return [input_param.name for input_param in self.intermediate_inputs]

@property
def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs]

@property
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]


class PipelineBlock(ModularPipelineBlocks):
"""
@@ -2825,3 +2841,8 @@ def _dict_to_component_spec(
type_hint=type_hint,
**spec_dict,
)

def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
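For orientation, here is a minimal usage sketch of the two additions above. Only the name-list properties and `set_progress_bar_config` come from this diff; the block class, import path, and construction helper in the sketch are assumptions for illustration.

```python
# Illustrative sketch only. The class name and import path are assumed;
# the properties and set_progress_bar_config are the additions from this PR.
from diffusers.modular_pipelines import StableDiffusionXLAutoBlocks  # assumed import path

blocks = StableDiffusionXLAutoBlocks()

# The new properties expose plain name lists instead of InputParam/OutputParam objects.
print(blocks.input_names)                # user-facing input names
print(blocks.intermediate_output_names)  # intermediate output names

# set_progress_bar_config fans out to every sub-block that supports it, so a
# single call configures progress bars across the whole composed pipeline.
pipe = blocks.init_pipeline()            # assumed construction helper
pipe.set_progress_bar_config(disable=True)
```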
@@ -744,8 +744,6 @@ def prepare_latents_inpaint(
timestep=None,
is_strength_max=True,
add_noise=True,
return_noise=False,
return_image_latents=False,
Comment on lines -747 to -748 (Member): Any reason behind this change?
):
shape = (
batch_size,
@@ -768,7 +766,7 @@
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
elif latents is None and not is_strength_max:
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +784,7 @@
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)

outputs = (latents,)

if return_noise:
outputs += (noise,)

if return_image_latents:
outputs += (image_latents,)
outputs = (latents, noise, image_latents)

return outputs

@@ -864,7 +856,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor

block_state.latents, block_state.noise = self.prepare_latents_inpaint(
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
components,
block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents,
@@ -878,8 +870,6 @@
timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise,
return_noise=True,
return_image_latents=False,
)

# 7. Prepare mask latent variables
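The net effect of the hunks above is that `prepare_latents_inpaint` now always returns a fixed `(latents, noise, image_latents)` tuple instead of a tuple whose length depended on `return_noise`/`return_image_latents`, and the caller in `__call__` unpacks all three values unconditionally. A standalone sketch of the pattern (integers standing in for tensors, not diffusers code):

```python
# Standalone illustration of the refactor: flag-controlled variable-length
# returns are replaced by a fixed three-tuple.

def prepare_old(return_noise=False, return_image_latents=False):
    latents, noise, image_latents = 1, 2, 3
    outputs = (latents,)
    if return_noise:
        outputs += (noise,)
    if return_image_latents:
        outputs += (image_latents,)
    return outputs  # tuple length depends on the flags

def prepare_new():
    latents, noise, image_latents = 1, 2, 3
    return latents, noise, image_latents  # always three values

latents, noise = prepare_old(return_noise=True)  # caller must match the flags it passed
latents, noise, image_latents = prepare_new()    # unconditional three-way unpack
```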
57 changes: 31 additions & 26 deletions tests/pipelines/pipeline_params.py
@@ -20,12 +20,6 @@
]
)

TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
Comment (Collaborator Author): I just rearranged a bit, i.e. put all the batch inputs together and all the image inputs together; did not delete or add anything.

TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])

IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])

IMAGE_VARIATION_PARAMS = frozenset(
[
"image",
@@ -35,8 +29,6 @@
]
)

IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])

TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[
"prompt",
@@ -50,8 +42,6 @@
]
)

TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])

TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
# Text guided image variation with an image mask
@@ -67,8 +57,6 @@
]
)

TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])

IMAGE_INPAINTING_PARAMS = frozenset(
[
# image variation with an image mask
@@ -80,8 +68,6 @@
]
)

IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])

IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
"example_image",
@@ -93,20 +79,12 @@
]
)

IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])

CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])

CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])

UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])

UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])

UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])

UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])

TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
@@ -119,11 +97,38 @@
]
)

TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])

TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])

TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
# image params
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])

IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])


# batch params
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])

IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])

TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])

TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])

IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])

IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])

UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])

UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])

TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])

TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])

VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])

# callback params
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
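As the author's comment notes, this file only regroups the frozensets (call params, then image params, batch params, and callback params); downstream tests keep importing the same names. A hypothetical consumer, with the test class and mixin convention assumed rather than taken from this diff:

```python
# Hypothetical consumer; only the frozenset names come from pipeline_params.py.
# The relative import assumes this file sits inside the tests/pipelines package,
# and the attribute layout follows the usual PipelineTesterMixin convention.
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS


class MyPipelineFastTests:  # would normally also inherit PipelineTesterMixin
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS  # inputs duplicated per batch element in batching tests
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS  # inputs that are images
```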