Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/diffusers/modular_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
Expand Down Expand Up @@ -59,7 +58,6 @@
LoopSequentialPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
Expand Down
38 changes: 7 additions & 31 deletions src/diffusers/modular_pipelines/flux/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline

Expand Down Expand Up @@ -231,7 +231,7 @@ def _get_initial_timesteps_and_optionals(
return timesteps, num_inference_steps, sigmas, guidance


class FluxInputStep(PipelineBlock):
class FluxInputStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -249,11 +249,6 @@ def description(self) -> str:
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
Expand Down Expand Up @@ -322,7 +317,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
return components, state


class FluxSetTimestepsStep(PipelineBlock):
class FluxSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -340,14 +335,10 @@ def inputs(self) -> List[InputParam]:
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"batch_size",
required=True,
Expand Down Expand Up @@ -398,7 +389,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
return components, state


class FluxImg2ImgSetTimestepsStep(PipelineBlock):
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -420,11 +411,6 @@ def inputs(self) -> List[InputParam]:
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"batch_size",
required=True,
Expand Down Expand Up @@ -497,7 +483,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
return components, state


class FluxPrepareLatentsStep(PipelineBlock):
class FluxPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -515,11 +501,6 @@ def inputs(self) -> List[InputParam]:
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
]

@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
Expand Down Expand Up @@ -621,7 +602,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
return components, state


class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -639,11 +620,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
]

@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"image_latents",
Expand Down
11 changes: 3 additions & 8 deletions src/diffusers/modular_pipelines/flux/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ...models import AutoencoderKL
from ...utils import logging
from ...video_processor import VaeImageProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


Expand All @@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
return latents


class FluxDecodeStep(PipelineBlock):
class FluxDecodeStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -70,17 +70,12 @@ def inputs(self) -> List[Tuple[str, Any]]:
InputParam("output_type", default="pil"),
InputParam("height", default=1024),
InputParam("width", default=1024),
]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
]

@property
Expand Down
13 changes: 5 additions & 8 deletions src/diffusers/modular_pipelines/flux/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
Expand All @@ -32,7 +32,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class FluxLoopDenoiser(PipelineBlock):
class FluxLoopDenoiser(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -49,11 +49,8 @@ def description(self) -> str:

@property
def inputs(self) -> List[Tuple[str, Any]]:
return [InputParam("joint_attention_kwargs")]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
Expand Down Expand Up @@ -113,7 +110,7 @@ def __call__(
return components, block_state


class FluxLoopAfterDenoiser(PipelineBlock):
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"

@property
Expand Down Expand Up @@ -175,7 +172,7 @@ def loop_expected_components(self) -> List[ComponentSpec]:
]

@property
def loop_intermediate_inputs(self) -> List[InputParam]:
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
Expand Down
13 changes: 6 additions & 7 deletions src/diffusers/modular_pipelines/flux/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline

Expand Down Expand Up @@ -67,7 +67,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


class FluxVaeEncoderStep(PipelineBlock):
class FluxVaeEncoderStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand All @@ -88,11 +88,10 @@ def expected_components(self) -> List[ComponentSpec]:

@property
def inputs(self) -> List[InputParam]:
return [InputParam("image", required=True), InputParam("height"), InputParam("width")]

@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
Expand Down Expand Up @@ -157,7 +156,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
return components, state


class FluxTextEncoderStep(PipelineBlock):
class FluxTextEncoderStep(ModularPipelineBlocks):
model_name = "flux"

@property
Expand Down
Loading