
Commit 1bea2d8

style
1 parent 4b367e8 commit 1bea2d8

5 files changed, +116 −122 lines changed

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 69 additions & 50 deletions
@@ -15,13 +15,11 @@
 import inspect
 from typing import Any, List, Optional, Tuple, Union
 
-import PIL
 import torch
 
 from ...configuration_utils import FrozenDict
-from ...guiders import ClassifierFreeGuidance
 from ...image_processor import VaeImageProcessor
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
+from ...models import ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
 from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import logging
@@ -591,7 +589,11 @@ def intermediate_inputs(self) -> List[str]:
                 type_hint=torch.Tensor,
                 description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
             ),
-            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs, can be generated in input step.",
+            ),
         ]
 
     @property
@@ -618,7 +620,6 @@ def prepare_latents(
         is_strength_max=True,
         add_noise=True,
     ):
-
         batch_size = image_latents.shape[0]
 
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -640,46 +641,50 @@
 
         return latents, noise
 
-
-
     def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):
-
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
-            raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
+            raise ValueError(
+                f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
+            )
 
         if not (mask.shape[0] == 1 or mask.shape[0] == batch_size):
             raise ValueError(f"mask should have have batch size 1 or {batch_size}, but got {mask.shape[0]}")
-
+
         if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
-            raise ValueError(f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")
-
-
+            raise ValueError(
+                f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}"
+            )
+
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
+
         self.check_inputs(
             batch_size=block_state.batch_size,
-            image_latents=block_state.image_latents,
-            mask=block_state.mask,
-            masked_image_latents=block_state.masked_image_latents,
-        )
+            image_latents=block_state.image_latents,
+            mask=block_state.mask,
+            masked_image_latents=block_state.masked_image_latents,
+        )
 
         dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
         device = components._execution_device
-
-        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
-
+
+        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(
+            final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
+        )
 
         # 7. Prepare mask latent variables
         block_state.mask = block_state.mask.to(device=device, dtype=dtype)
-        block_state.mask = block_state.mask.repeat(final_batch_size//block_state.mask.shape[0], 1, 1, 1)
+        block_state.mask = block_state.mask.repeat(final_batch_size // block_state.mask.shape[0], 1, 1, 1)
 
         block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
-        block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1)
-
+        block_state.masked_image_latents = block_state.masked_image_latents.repeat(
+            final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1
+        )
+
         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
             block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)
@@ -698,7 +703,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
             add_noise=add_noise,
         )
 
-
         self.set_block_state(state, block_state)
 
         return components, state
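The inpaint block above (and the img2img and time-ids blocks below) all lean on the same batch-expansion idiom: a tensor whose batch dimension is 1 or batch_size is tiled up to final_batch_size with repeat. A minimal, self-contained sketch of the idiom; the helper name expand_to_batch is hypothetical and not part of this commit:

import torch

# Hypothetical helper illustrating the repeat idiom from the hunks above:
# tile a tensor whose batch is 1 (or batch_size) up to final_batch_size
# along dim 0, after checking that the target is an exact multiple.
def expand_to_batch(t: torch.Tensor, final_batch_size: int) -> torch.Tensor:
    if final_batch_size % t.shape[0] != 0:
        raise ValueError(f"cannot expand batch {t.shape[0]} to {final_batch_size}")
    return t.repeat(final_batch_size // t.shape[0], 1, 1, 1)

latents = torch.randn(1, 4, 64, 64)
print(expand_to_batch(latents, 4).shape)  # torch.Size([4, 4, 64, 64])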
@@ -755,11 +759,13 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
             )
         ]
-
+
     def check_inputs(self, batch_size, image_latents):
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
-            raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
-
+            raise ValueError(
+                f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}"
+            )
+
     @staticmethod
     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
        if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
@@ -788,7 +794,9 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
 
         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(
+            final_batch_size // block_state.image_latents.shape[0], 1, 1, 1
+        )
 
         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
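For context, prepare_latents in this img2img block seeds denoising by noising the encoded image latents to the start timestep selected by strength. A rough sketch under assumed toy shapes and a default-configured scheduler, not the block's actual code:

import torch
from diffusers import EulerDiscreteScheduler

# Rough sketch (assumed toy shapes, default scheduler config): img2img
# latents are the encoded image noised to the chosen start timestep.
scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(50)
image_latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(image_latents)
latent_timestep = scheduler.timesteps[10:11]  # timesteps descend, so a later index adds less noise
latents = scheduler.add_noise(image_latents, noise, latent_timestep)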
@@ -935,7 +943,9 @@ def expected_configs(self) -> List[ConfigSpec]:
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("unet", UNet2DConditionModel),]
+        return [
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
 
     @property
     def description(self) -> str:
@@ -976,7 +986,11 @@ def intermediate_inputs(self) -> List[InputParam]:
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
-            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs, can be generated in input step.",
+            ),
         ]
 
     @property
@@ -1052,7 +1066,7 @@ def _get_add_time_ids(
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
+
         device = components._execution_device
         dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype
 
@@ -1087,7 +1101,9 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
             text_encoder_projection_dim=text_encoder_projection_dim,
         )
         block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
-        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
+        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
+            device=device
+        )
 
         self.set_block_state(state, block_state)
         return components, state
@@ -1102,7 +1118,9 @@ def description(self) -> str:
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("unet", UNet2DConditionModel),]
+        return [
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
 
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
@@ -1196,7 +1214,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         original_size = block_state.original_size or (height, width)
         target_size = block_state.target_size or (height, width)
 
-
         block_state.add_time_ids = self._get_add_time_ids(
             components,
             original_size,
@@ -1218,7 +1235,9 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
             block_state.negative_add_time_ids = block_state.add_time_ids
 
         block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
-        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
+        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(
+            device=device
+        )
 
         self.set_block_state(state, block_state)
         return components, state
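For orientation, _get_add_time_ids packs SDXL's micro-conditioning into a single tensor before the repeat above. An illustrative sketch with made-up sizes, not the pipeline's actual helper:

import torch

# Illustrative sketch: SDXL micro-conditioning concatenates original size,
# crop coordinates, and target size into a (1, 6) tensor, which is then
# tiled to the final batch size exactly as in the hunk above.
original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.float32)
final_batch_size = 2
add_time_ids = add_time_ids.repeat(final_batch_size, 1)
print(add_time_ids.shape)  # torch.Size([2, 6])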
@@ -1229,7 +1248,9 @@ class StableDiffusionXLLCMStep(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("unet", UNet2DConditionModel),]
+        return [
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
 
     @property
     def description(self) -> str:
@@ -1290,30 +1311,30 @@ def get_guidance_scale_embedding(
         assert emb.shape == (w.shape[0], embedding_dim)
         return emb
 
-
     def check_input(self, unet, embedded_guidance_scale):
-
         if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None:
-            raise ValueError(f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None")
+            raise ValueError(
+                f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None"
+            )
 
         if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None:
-            raise ValueError(f"unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
+            raise ValueError("unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
 
-
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
+
         device = components._execution_device
         dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype
 
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
 
-
         # Optionally get Guidance Scale Embedding for LCM
         block_state.timestep_cond = None
-
-        guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
+
+        guidance_scale_tensor = (
+            torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
+        )
         block_state.timestep_cond = self.get_guidance_scale_embedding(
             guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
         ).to(device=device, dtype=dtype)
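The get_guidance_scale_embedding used here is the sinusoidal w-embedding that conditions LCM-style UNets through time_cond_proj_dim. A self-contained sketch mirroring what the method computes; the free-function form is mine:

import torch

# Sketch of the sinusoidal guidance-scale (w) embedding: scale w by 1000,
# build log-spaced frequencies, multiply, then concatenate sin and cos
# halves to reach embedding_dim (zero-padding odd dims).
def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim) * -emb)
    emb = w[:, None].float() * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

print(guidance_scale_embedding(torch.tensor([7.5 - 1.0])).shape)  # torch.Size([1, 256])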
@@ -1476,9 +1497,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         if isinstance(controlnet, MultiControlNetModel) and isinstance(
             block_state.controlnet_conditioning_scale, float
         ):
-            block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
-                controlnet.nets
-            )
+            block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
         else:
             block_state.conditioning_scale = block_state.controlnet_conditioning_scale
 
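The branch above broadcasts a single float conditioning scale across every net in a MultiControlNetModel. A toy sketch of the same broadcast; the helper name is hypothetical:

from typing import List, Union

# Hypothetical helper mirroring the branch above: a lone float scale is
# replicated once per ControlNet; lists pass through unchanged.
def broadcast_scale(scale: Union[float, List[float]], num_nets: int) -> List[float]:
    return [scale] * num_nets if isinstance(scale, float) else scale

print(broadcast_scale(0.5, 3))  # [0.5, 0.5, 0.5]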

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 1 addition & 3 deletions
@@ -130,9 +130,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
             latents_std = (
                 torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
             )
-            latents = (
-                latents * latents_std / components.vae.config.scaling_factor + latents_mean
-            )
+            latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean
         else:
             latents = latents / components.vae.config.scaling_factor
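Both branches here invert the latent normalization applied at encode time before the VAE decode. A minimal sketch with assumed argument names and SDXL's published scaling factor:

import torch

# Minimal sketch (assumed names): undo latent normalization before decoding.
# With latents_mean/latents_std present, invert the affine normalization;
# otherwise only the scaling_factor is undone.
def unscale_latents(latents, scaling_factor, latents_mean=None, latents_std=None):
    if latents_mean is not None and latents_std is not None:
        mean = torch.tensor(latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
        std = torch.tensor(latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
        return latents * std / scaling_factor + mean
    return latents / scaling_factor

latents = torch.randn(1, 4, 64, 64)
print(unscale_latents(latents, 0.13025).shape)  # torch.Size([1, 4, 64, 64])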
