Merged
Changes from all commits
Commits
38 commits
c8b5d56
make loader optional
yiyixuxu May 1, 2025
7b86fce
remove lora step and ip-adapter step -> no longer needed
yiyixuxu May 2, 2025
7ca860c
rename pipeline -> components, data -> block_state
yiyixuxu May 2, 2025
efd70b7
separate controlnet step into input + denoise
yiyixuxu May 3, 2025
43ac1ff
refactor controlnet union
yiyixuxu May 4, 2025
dc4dbfe
refactor pipeline/block states so that it can dynamically accept kwargs
yiyixuxu May 6, 2025
f552773
remove controlnet union denoise step, refactor & reuse controlnet den…
yiyixuxu May 6, 2025
16b6583
allow input_fields as input & update message
yiyixuxu May 8, 2025
d89631f
update input formatting, consider kwargs_type inputs with no name, e/…
yiyixuxu May 8, 2025
0f0618f
refactor the denoise step using LoopSequential! also add a new file f…
yiyixuxu May 8, 2025
c677d52
change warning to debug
yiyixuxu May 9, 2025
2b361a2
fix get_execution blocks with LoopSequential
yiyixuxu May 9, 2025
2017ae5
fix auto denoise so all tests pass
yiyixuxu May 9, 2025
cf01aae
update imports on guiders
yiyixuxu May 10, 2025
462429b
remove modular related changes from pipelines folder
yiyixuxu May 10, 2025
0acb5e1
made a modular_pipelines folder!
yiyixuxu May 10, 2025
153ae34
update __init__
yiyixuxu May 10, 2025
796453c
add notes
yiyixuxu May 11, 2025
144eae4
add block state will also make sure modified intermediates_inputs will…
yiyixuxu May 11, 2025
522e827
move block mappings to its own file
yiyixuxu May 11, 2025
5cde77f
make inputs truly immutable, remove the output logic in sequential pi…
yiyixuxu May 12, 2025
58358c2
decode block, if skip decoding do not need to update latent
yiyixuxu May 12, 2025
506a8ea
fix imports
yiyixuxu May 13, 2025
e2491af
fix import
yiyixuxu May 13, 2025
a0deefb
fix more
yiyixuxu May 13, 2025
a7fb2d2
remove the output step
yiyixuxu May 13, 2025
8ad14a5
make generator intermediates (it is mutable)
yiyixuxu May 13, 2025
96ce674
after_denoise -> decoders
yiyixuxu May 14, 2025
27c1158
add a to-do for guider config mixin
yiyixuxu May 18, 2025
d0fbf74
refactor component spec: replace create/create_from_pretrained/create…
yiyixuxu May 18, 2025
163341d
refactor modular loader: 1. load only load (pretrained components onl…
yiyixuxu May 18, 2025
73ab572
update components manager
yiyixuxu May 18, 2025
61dac3b
up
yiyixuxu May 19, 2025
4968edc
remove the duplicated components_manager file I forgot to delete
yiyixuxu May 20, 2025
de6ab6b
fix import in block mapping
yiyixuxu May 20, 2025
eb94150
add a to-do for modular loader
yiyixuxu May 20, 2025
1b89ac1
prepare_latents_img2img pipeline method -> function, maybe do the sam…
yiyixuxu May 20, 2025
d136ae3
update input for loop blocks, do not need to include intermediate
yiyixuxu May 20, 2025
48 changes: 41 additions & 7 deletions src/diffusers/__init__.py
@@ -39,6 +39,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"modular_pipelines": [],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -254,13 +255,19 @@
"KarrasVePipeline",
"LDMPipeline",
"LDMSuperResolutionPipeline",
"ModularLoader",
"PNDMPipeline",
"RePaintPipeline",
"ScoreSdeVePipeline",
"StableDiffusionMixin",
]
)
_import_structure["modular_pipelines"].extend(
[
"ModularLoader",
"ComponentSpec",
"ComponentsManager",
]
)
_import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
@@ -509,12 +516,10 @@
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLModularLoader",
"StableDiffusionXLPAGImg2ImgPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLAutoPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
@@ -541,6 +546,24 @@
]
)


try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_objects"] = [
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
]

else:
_import_structure["modular_pipelines"].extend(
[
"StableDiffusionXLAutoPipeline",
"StableDiffusionXLModularLoader",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -864,12 +887,16 @@
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
ModularLoader,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .modular_pipelines import (
ModularLoader,
ComponentSpec,
ComponentsManager,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
@@ -1097,12 +1124,10 @@
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularLoader,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
@@ -1127,7 +1152,16 @@
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)

try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
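The `__init__.py` hunks above register the new `modular_pipelines` entries in two places: the lazy `_import_structure` mapping near the top of the file and the eager import branch further down, with the SDXL-specific classes additionally gated behind the torch/transformers check. The snippet below is a minimal, self-contained sketch of that optional-dependency pattern; it reuses the real `diffusers.utils` helpers, but the surrounding structure is simplified for illustration and is not the actual file.

# Minimal sketch of the optional-dependency gating introduced above; the helper
# names come from diffusers.utils, but this standalone structure is simplified
# for illustration and is not the actual __init__.py.
from diffusers.utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
)

_import_structure = {"modular_pipelines": []}

# Loader/manager classes are registered unconditionally.
_import_structure["modular_pipelines"].extend(
    ["ModularLoader", "ComponentSpec", "ComponentsManager"]
)

try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Without torch + transformers, dummy placeholders are exported instead,
    # so the missing dependency only surfaces as an informative error on use.
    from diffusers.utils import dummy_torch_and_transformers_objects

    _import_structure["utils.dummy_torch_and_transformers_objects"] = [
        name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
    ]
else:
    # The SDXL modular classes need both libraries, so they are gated.
    _import_structure["modular_pipelines"].extend(
        ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"]
    )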
12 changes: 8 additions & 4 deletions src/diffusers/guiders/adaptive_projected_guidance.py
@@ -13,14 +13,14 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


class AdaptiveProjectedGuidance(BaseGuidance):
@@ -73,14 +73,18 @@ def __init__(
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None

def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:

if input_fields is None:
input_fields = self._input_fields

if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

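Every guider touched by this PR gets the same change to `prepare_inputs`: an optional `input_fields` argument that takes precedence over the mapping previously registered on the guider, so a block can supply the field routing at call time. The toy class below is a self-contained sketch of just that fallback-and-routing behaviour; the class and attribute names are illustrative stand-ins, not the real diffusers guiders.

# Self-contained sketch of the fallback added to prepare_inputs; TinyGuidance and
# the attribute names below are illustrative stand-ins, not diffusers classes.
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union

FieldMap = Dict[str, Union[str, Tuple[str, str]]]


class TinyGuidance:
    def __init__(self, num_conditions: int = 2) -> None:
        self._input_fields: Optional[FieldMap] = None
        self.num_conditions = num_conditions

    def set_input_fields(self, input_fields: FieldMap) -> None:
        self._input_fields = input_fields

    def prepare_inputs(self, data, input_fields: Optional[FieldMap] = None) -> List[SimpleNamespace]:
        # New behaviour: an explicit mapping wins; otherwise fall back to the
        # one registered earlier via set_input_fields().
        if input_fields is None:
            input_fields = self._input_fields
        batches = []
        for i in range(self.num_conditions):
            batch = {}
            for key, value in input_fields.items():
                # A tuple routes a different source attribute to each condition
                # (index 0 = conditional, index 1 = unconditional); a plain
                # string reuses the same attribute for every batch.
                source = value[i] if isinstance(value, tuple) else value
                batch[key] = getattr(data, source)
            batches.append(SimpleNamespace(**batch))
        return batches


state = SimpleNamespace(prompt_embeds="cond_embeds", negative_prompt_embeds="uncond_embeds", latents="x_t")
fields: FieldMap = {"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), "latents": "latents"}

# Passing the mapping per call (new) gives the same result as registering it
# beforehand with set_input_fields() (old).
print(TinyGuidance().prepare_inputs(state, input_fields=fields))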
12 changes: 8 additions & 4 deletions src/diffusers/guiders/auto_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple

import torch

@@ -22,7 +22,7 @@
from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


class AutoGuidance(BaseGuidance):
@@ -120,11 +120,15 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)

def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:

if input_fields is None:
input_fields = self._input_fields

tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

12 changes: 8 additions & 4 deletions src/diffusers/guiders/classifier_free_guidance.py
@@ -13,14 +13,14 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


class ClassifierFreeGuidance(BaseGuidance):
@@ -75,11 +75,15 @@ def __init__(
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation

def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:

if input_fields is None:
input_fields = self._input_fields

tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

12 changes: 8 additions & 4 deletions src/diffusers/guiders/classifier_free_zero_star_guidance.py
@@ -13,14 +13,14 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


class ClassifierFreeZeroStarGuidance(BaseGuidance):
@@ -73,11 +73,15 @@ def __init__(
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation

def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:

if input_fields is None:
input_fields = self._input_fields

tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

8 changes: 4 additions & 4 deletions src/diffusers/guiders/guider_utils.py
@@ -20,7 +20,7 @@


if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -171,10 +171,10 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da
Returns:
`BlockState`: The prepared batch of data.
"""
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState

if input_fields is None:
raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
data_batch = {}
for key, value in input_fields.items():
try:
Expand All @@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)

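Note the behavioural change in `_prepare_batch`: a field missing from the block state is now skipped with a debug-level log instead of raising. When wiring up a new block, the standard diffusers logging helpers can surface those messages; the sketch below shows one way to do that (the quoted message text is paraphrased from the diff above).

# Sketch: surface the new "skipping" debug messages from guider_utils while
# debugging a modular pipeline, using the standard diffusers logging helpers.
from diffusers.utils import logging

logging.set_verbosity_debug()

# ... build and run the modular pipeline as usual; missing BlockState attributes
# now appear as debug lines like
# "`data` does not have attribute(s) ..., skipping."
# instead of raising a ValueError.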
12 changes: 8 additions & 4 deletions src/diffusers/guiders/skip_layer_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple

import torch

@@ -22,7 +22,7 @@
from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


class SkipLayerGuidance(BaseGuidance):
@@ -156,7 +156,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)

def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:

if input_fields is None:
input_fields = self._input_fields

if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -168,7 +172,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches

12 changes: 8 additions & 4 deletions src/diffusers/guiders/smoothed_energy_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple

import torch

@@ -22,7 +22,7 @@
from .guider_utils import BaseGuidance, rescale_noise_cfg

if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState
from ..modular_pipelines.modular_pipeline import BlockState


class SmoothedEnergyGuidance(BaseGuidance):
@@ -149,7 +149,11 @@ def cleanup_models(self, denoiser: torch.nn.Module):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)

def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:

if input_fields is None:
input_fields = self._input_fields

if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -161,7 +165,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
