Skip to content

Commit 01300a3

Browse files
committed
up
1 parent 65ba892 commit 01300a3

File tree

3 files changed

+46
-24
lines changed

3 files changed

+46
-24
lines changed

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
import torch
2020

2121
from ...configuration_utils import FrozenDict
22+
from ...guiders import ClassifierFreeGuidance
2223
from ...image_processor import VaeImageProcessor
23-
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel
24+
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
2425
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
2526
from ...schedulers import EulerDiscreteScheduler
2627
from ...utils import logging
@@ -266,37 +267,37 @@ def intermediate_outputs(self) -> List[str]:
266267
OutputParam(
267268
"prompt_embeds",
268269
type_hint=torch.Tensor,
269-
kwargs_type="guider_input_fields",
270+
kwargs_type="guider_input_fields", # already in intermediates state but declare here again for guider_input_fields
270271
description="text embeddings used to guide the image generation",
271272
),
272273
OutputParam(
273274
"negative_prompt_embeds",
274275
type_hint=torch.Tensor,
275-
kwargs_type="guider_input_fields",
276+
kwargs_type="guider_input_fields", # already in intermediates state but declare here again for guider_input_fields
276277
description="negative text embeddings used to guide the image generation",
277278
),
278279
OutputParam(
279280
"pooled_prompt_embeds",
280281
type_hint=torch.Tensor,
281-
kwargs_type="guider_input_fields",
282+
kwargs_type="guider_input_fields", # already in intermediates state but declare here again for guider_input_fields
282283
description="pooled text embeddings used to guide the image generation",
283284
),
284285
OutputParam(
285286
"negative_pooled_prompt_embeds",
286287
type_hint=torch.Tensor,
287-
kwargs_type="guider_input_fields",
288+
kwargs_type="guider_input_fields", # already in intermediates state but declare here again for guider_input_fields
288289
description="negative pooled text embeddings used to guide the image generation",
289290
),
290291
OutputParam(
291292
"ip_adapter_embeds",
292293
type_hint=List[torch.Tensor],
293-
kwargs_type="guider_input_fields",
294+
kwargs_type="guider_input_fields", # already in intermediates state but declare here again for guider_input_fields
294295
description="image embeddings for IP-Adapter",
295296
),
296297
OutputParam(
297298
"negative_ip_adapter_embeds",
298299
type_hint=List[torch.Tensor],
299-
kwargs_type="guider_input_fields",
300+
kwargs_type="guider_input_fields", # already in intermediates state but declare here again for guider_input_fields
300301
description="negative image embeddings for IP-Adapter",
301302
),
302303
]
@@ -683,12 +684,6 @@ def intermediate_outputs(self) -> List[str]:
683684
OutputParam(
684685
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
685686
),
686-
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"),
687-
OutputParam(
688-
"masked_image_latents",
689-
type_hint=torch.Tensor,
690-
description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)",
691-
),
692687
OutputParam(
693688
"noise",
694689
type_hint=torch.Tensor,
@@ -993,6 +988,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
993988
def expected_components(self) -> List[ComponentSpec]:
994989
return [
995990
ComponentSpec("scheduler", EulerDiscreteScheduler),
991+
ComponentSpec("vae", AutoencoderKL),
996992
]
997993

998994
@property
@@ -1105,6 +1101,18 @@ def expected_configs(self) -> List[ConfigSpec]:
11051101
ConfigSpec("requires_aesthetics_score", False),
11061102
]
11071103

1104+
@property
1105+
def expected_components(self) -> List[ComponentSpec]:
1106+
return [
1107+
ComponentSpec("unet", UNet2DConditionModel),
1108+
ComponentSpec(
1109+
"guider",
1110+
ClassifierFreeGuidance,
1111+
config=FrozenDict({"guidance_scale": 7.5}),
1112+
default_creation_method="from_config",
1113+
),
1114+
]
1115+
11081116
@property
11091117
def description(self) -> str:
11101118
return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process"
@@ -1315,6 +1323,18 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
13151323
def description(self) -> str:
13161324
return "Step that prepares the additional conditioning for the text-to-image generation process"
13171325

1326+
@property
1327+
def expected_components(self) -> List[ComponentSpec]:
1328+
return [
1329+
ComponentSpec("unet", UNet2DConditionModel),
1330+
ComponentSpec(
1331+
"guider",
1332+
ClassifierFreeGuidance,
1333+
config=FrozenDict({"guidance_scale": 7.5}),
1334+
default_creation_method="from_config",
1335+
),
1336+
]
1337+
13181338
@property
13191339
def inputs(self) -> List[Tuple[str, Any]]:
13201340
return [

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,17 @@ def description(self) -> str:
167167
+ "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
168168
)
169169

170+
@property
171+
def expected_components(self) -> List[ComponentSpec]:
172+
return [
173+
ComponentSpec(
174+
"image_processor",
175+
VaeImageProcessor,
176+
config=FrozenDict({"vae_scale_factor": 8}),
177+
default_creation_method="from_config",
178+
),
179+
]
180+
170181
@property
171182
def inputs(self) -> List[Tuple[str, Any]]:
172183
return [
@@ -190,16 +201,6 @@ def intermediate_inputs(self) -> List[str]:
190201
),
191202
]
192203

193-
@property
194-
def intermediate_outputs(self) -> List[str]:
195-
return [
196-
OutputParam(
197-
"images",
198-
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
199-
description="The generated images with the mask overlaid",
200-
)
201-
]
202-
203204
@torch.no_grad()
204205
def __call__(self, components, state: PipelineState) -> PipelineState:
205206
block_state = self.get_block_state(state)

src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def num_channels_latents(self):
9191
return num_channels_latents
9292

9393

94-
# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks
94+
# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
95+
# auto_docstring
9596
SDXL_INPUTS_SCHEMA = {
9697
"prompt": InputParam(
9798
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"

0 commit comments

Comments
 (0)