Skip to content

Commit db94ca8

Browse files
committed
add controlnet inpaint + more refactor
1 parent 6985906 commit db94ca8

File tree

2 files changed

+141
-98
lines changed

2 files changed

+141
-98
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
305305
block = self.trigger_to_block_map[input_name]
306306
break
307307

308+
if block is None:
309+
logger.warning(f"skipping auto block: {self.__class__.__name__}")
310+
return pipeline, state
311+
308312
try:
309313
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
310314
return block(pipeline, state)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 137 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import PIL
1919
import torch
20+
from collections import OrderedDict
2021

2122
from ...guider import CFGGuider
2223
from ...image_processor import VaeImageProcessor
@@ -122,64 +123,6 @@ def retrieve_latents(
122123
raise AttributeError("Could not access latents of provided encoder_output")
123124

124125

125-
class StableDiffusionXLOutputStep(PipelineBlock):
126-
model_name = "stable-diffusion-xl"
127-
128-
@property
129-
def inputs(self) -> List[Tuple[str, Any]]:
130-
return [("return_dict", True)]
131-
132-
@property
133-
def intermediates_outputs(self) -> List[str]:
134-
return ["images"]
135-
136-
@torch.no_grad()
137-
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
138-
images = state.get_intermediate("images")
139-
return_dict = state.get_input("return_dict")
140-
141-
if not return_dict:
142-
output = (images,)
143-
else:
144-
output = StableDiffusionXLPipelineOutput(images=images)
145-
state.add_output("images", output)
146-
return pipeline, state
147-
148-
149-
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
150-
model_name = "stable-diffusion-xl"
151-
152-
@property
153-
def inputs(self) -> List[Tuple[str, Any]]:
154-
return [
155-
("image", None),
156-
("mask_image", None),
157-
("padding_mask_crop", None),
158-
]
159-
160-
@property
161-
def intermediates_inputs(self) -> List[str]:
162-
return ["crops_coords", "images"]
163-
164-
@property
165-
def intermediates_outputs(self) -> List[str]:
166-
return ["images"]
167-
168-
@torch.no_grad()
169-
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
170-
original_image = state.get_input("image")
171-
padding_mask_crop = state.get_input("padding_mask_crop")
172-
mask_image = state.get_input("mask_image")
173-
images = state.get_intermediate("images")
174-
crops_coords = state.get_intermediate("crops_coords")
175-
176-
if padding_mask_crop is not None and crops_coords is not None:
177-
images = [pipeline.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in images]
178-
179-
state.add_intermediate("images", images)
180-
181-
return pipeline, state
182-
183126

184127
class StableDiffusionXLInputStep(PipelineBlock):
185128
model_name = "stable-diffusion-xl"
@@ -376,7 +319,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
376319
return pipeline, state
377320

378321

379-
class StableDiffusionXLVAEEncoderStep(PipelineBlock):
322+
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
380323
expected_components = ["vae"]
381324
model_name = "stable-diffusion-xl"
382325

@@ -589,7 +532,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
589532
return pipeline, state
590533

591534

592-
class StableDiffusionXLInpaintVaeEncodeStep(PipelineBlock):
535+
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
593536
expected_components = ["vae"]
594537
model_name = "stable-diffusion-xl"
595538

@@ -694,7 +637,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin
694637
return pipeline, state
695638

696639

697-
# inpaint-specific
698640
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
699641
expected_components = ["scheduler"]
700642
model_name = "stable-diffusion-xl"
@@ -804,27 +746,22 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
804746
@property
805747
def inputs(self) -> List[Tuple[str, Any]]:
806748
return [
807-
("height", None),
808-
("width", None),
809749
("generator", None),
810750
("latents", None),
811751
("num_images_per_prompt", 1),
812-
("image", None),
813752
("denoising_start", None),
814753
]
815754

816755
@property
817756
def intermediates_inputs(self) -> List[str]:
818-
return ["batch_size", "dtype", "latent_timestep"]
757+
return ["batch_size", "dtype", "latent_timestep", "image_latents"]
819758

820759
@property
821760
def intermediates_outputs(self) -> List[str]:
822761
return ["latents"]
823762

824763
def __init__(self):
825764
super().__init__()
826-
self.auxiliaries["image_processor"] = VaeImageProcessor()
827-
self.components["vae"] = None
828765
self.components["scheduler"] = None
829766

830767
@torch.no_grad()
@@ -834,24 +771,22 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin
834771
generator = state.get_input("generator")
835772

836773
# image to image only
837-
image = state.get_input("image")
838774
denoising_start = state.get_input("denoising_start")
839775

840776
batch_size = state.get_intermediate("batch_size")
841777
dtype = state.get_intermediate("dtype")
842778
# image to image only
843779
latent_timestep = state.get_intermediate("latent_timestep")
780+
image_latents = state.get_intermediate("image_latents")
844781

845782
if dtype is None:
846783
dtype = pipeline.vae.dtype
847784

848785
device = pipeline._execution_device
849-
850-
image = pipeline.image_processor.preprocess(image)
851786
add_noise = True if denoising_start is None else False
852787
if latents is None:
853788
latents = pipeline.prepare_latents_img2img(
854-
image,
789+
image_latents,
855790
latent_timestep,
856791
batch_size,
857792
num_images_per_prompt,
@@ -1723,6 +1658,81 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
17231658
return pipeline, state
17241659

17251660

1661+
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
1662+
model_name = "stable-diffusion-xl"
1663+
1664+
@property
1665+
def inputs(self) -> List[Tuple[str, Any]]:
1666+
return [
1667+
("image", None),
1668+
("mask_image", None),
1669+
("padding_mask_crop", None),
1670+
]
1671+
1672+
@property
1673+
def intermediates_inputs(self) -> List[str]:
1674+
return ["crops_coords", "images"]
1675+
1676+
@property
1677+
def intermediates_outputs(self) -> List[str]:
1678+
return ["images"]
1679+
1680+
@torch.no_grad()
1681+
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
1682+
original_image = state.get_input("image")
1683+
padding_mask_crop = state.get_input("padding_mask_crop")
1684+
mask_image = state.get_input("mask_image")
1685+
images = state.get_intermediate("images")
1686+
crops_coords = state.get_intermediate("crops_coords")
1687+
1688+
if padding_mask_crop is not None and crops_coords is not None:
1689+
images = [pipeline.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in images]
1690+
1691+
state.add_intermediate("images", images)
1692+
1693+
return pipeline, state
1694+
1695+
1696+
class StableDiffusionXLOutputStep(PipelineBlock):
1697+
model_name = "stable-diffusion-xl"
1698+
1699+
@property
1700+
def inputs(self) -> List[Tuple[str, Any]]:
1701+
return [("return_dict", True)]
1702+
1703+
@property
1704+
def intermediates_outputs(self) -> List[str]:
1705+
return ["images"]
1706+
1707+
@torch.no_grad()
1708+
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
1709+
images = state.get_intermediate("images")
1710+
return_dict = state.get_input("return_dict")
1711+
1712+
if not return_dict:
1713+
output = (images,)
1714+
else:
1715+
output = StableDiffusionXLPipelineOutput(images=images)
1716+
state.add_output("images", output)
1717+
return pipeline, state
1718+
1719+
1720+
class StableDiffusionXLDecodeStep(SequentialPipelineBlocks):
1721+
block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep]
1722+
block_names = ["decode", "output"]
1723+
1724+
1725+
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
1726+
block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep]
1727+
block_names = ["decode", "mask_overlay", "output"]
1728+
1729+
1730+
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
1731+
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
1732+
block_names = ["inpaint", "img2img"]
1733+
block_trigger_inputs = ["mask_image", "image"]
1734+
1735+
17261736
class StableDiffusionXLAutoSetTimestepsStep(AutoPipelineBlocks):
17271737
block_classes = [StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLSetTimestepsStep]
17281738
block_names = ["img2img", "text2img"]
@@ -1750,38 +1760,67 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
17501760
block_trigger_inputs = ["control_image", None]
17511761

17521762

1753-
class StableDiffusionXLDecodeStep(SequentialPipelineBlocks):
1754-
block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep]
1755-
block_names = ["decode", "output"]
1756-
1757-
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
1758-
block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep]
1759-
block_names = ["decode", "mask_overlay", "output"]
1760-
17611763
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
17621764
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
17631765
block_names = ["inpaint", "non-inpaint"]
17641766
block_trigger_inputs = ["padding_mask_crop", None]
17651767

1766-
class StableDiffusionXLAllSteps(SequentialPipelineBlocks):
1767-
block_classes = [
1768-
StableDiffusionXLInputStep,
1769-
StableDiffusionXLTextEncoderStep,
1770-
StableDiffusionXLAutoSetTimestepsStep,
1771-
StableDiffusionXLAutoPrepareLatentsStep,
1772-
StableDiffusionXLAutoPrepareAdditionalConditioningStep,
1773-
StableDiffusionXLAutoDenoiseStep,
1774-
StableDiffusionXLAutoDecodeStep
1775-
]
1776-
block_names = [
1777-
"input",
1778-
"text_encoder",
1779-
"set_timesteps",
1780-
"prepare_latents",
1781-
"prepare_add_cond",
1782-
"denoise",
1783-
"decode"
1784-
]
1768+
1769+
TEXT2IMAGE_BLOCKS = OrderedDict([
1770+
("input", StableDiffusionXLInputStep),
1771+
("text_encoder", StableDiffusionXLTextEncoderStep),
1772+
("set_timesteps", StableDiffusionXLAutoSetTimestepsStep),
1773+
("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep),
1774+
("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep),
1775+
("denoise", StableDiffusionXLAutoDenoiseStep),
1776+
("decode", StableDiffusionXLDecodeStep)
1777+
])
1778+
1779+
IMAGE2IMAGE_BLOCKS = OrderedDict([
1780+
("input", StableDiffusionXLInputStep),
1781+
("text_encoder", StableDiffusionXLTextEncoderStep),
1782+
("image_encoder", StableDiffusionXLVaeEncoderStep),
1783+
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
1784+
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
1785+
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
1786+
("denoise", StableDiffusionXLDenoiseStep),
1787+
("decode", StableDiffusionXLDecodeStep)
1788+
])
1789+
1790+
INPAINT_BLOCKS = OrderedDict([
1791+
("input", StableDiffusionXLInputStep),
1792+
("text_encoder", StableDiffusionXLTextEncoderStep),
1793+
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
1794+
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
1795+
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
1796+
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
1797+
("denoise", StableDiffusionXLDenoiseStep),
1798+
("decode", StableDiffusionXLInpaintDecodeStep)
1799+
])
1800+
1801+
CONTROLNET_BLOCKS = OrderedDict([
1802+
("denoise", StableDiffusionXLControlNetDenoiseStep),
1803+
])
1804+
1805+
AUTO_BLOCKS = OrderedDict([
1806+
("input", StableDiffusionXLInputStep),
1807+
("text_encoder", StableDiffusionXLTextEncoderStep),
1808+
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
1809+
("set_timesteps", StableDiffusionXLAutoSetTimestepsStep),
1810+
("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep),
1811+
("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep),
1812+
("denoise", StableDiffusionXLAutoDenoiseStep),
1813+
("decode", StableDiffusionXLAutoDecodeStep)
1814+
])
1815+
1816+
1817+
SDXL_SUPPORTED_BLOCKS = {
1818+
"text2img": TEXT2IMAGE_BLOCKS,
1819+
"img2img": IMAGE2IMAGE_BLOCKS,
1820+
"inpaint": INPAINT_BLOCKS,
1821+
"controlnet": CONTROLNET_BLOCKS,
1822+
"auto": AUTO_BLOCKS
1823+
}
17851824

17861825

17871826
class StableDiffusionXLModularPipeline(

0 commit comments

Comments
 (0)