1515import inspect
1616from typing import Any , List , Optional , Tuple , Union
1717
18- import PIL
1918import torch
2019
2120from ...configuration_utils import FrozenDict
22- from ...guiders import ClassifierFreeGuidance
2321from ...image_processor import VaeImageProcessor
24- from ...models import AutoencoderKL , ControlNetModel , ControlNetUnionModel , UNet2DConditionModel
22+ from ...models import ControlNetModel , ControlNetUnionModel , UNet2DConditionModel
2523from ...pipelines .controlnet .multicontrolnet import MultiControlNetModel
2624from ...schedulers import EulerDiscreteScheduler
2725from ...utils import logging
@@ -591,7 +589,11 @@ def intermediate_inputs(self) -> List[str]:
591589 type_hint = torch .Tensor ,
592590 description = "The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." ,
593591 ),
594- InputParam ("dtype" , type_hint = torch .dtype , description = "The dtype of the model inputs, can be generated in input step." ),
592+ InputParam (
593+ "dtype" ,
594+ type_hint = torch .dtype ,
595+ description = "The dtype of the model inputs, can be generated in input step." ,
596+ ),
595597 ]
596598
597599 @property
@@ -618,7 +620,6 @@ def prepare_latents(
618620 is_strength_max = True ,
619621 add_noise = True ,
620622 ):
621-
622623 batch_size = image_latents .shape [0 ]
623624
624625 if isinstance (generator , list ) and len (generator ) != batch_size :
@@ -640,46 +641,50 @@ def prepare_latents(
640641
641642 return latents , noise
642643
643-
644-
def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):
    """Validate the leading (batch) dimension of the inpainting tensors.

    Each tensor must have a first dimension of either 1 or `batch_size`.

    Args:
        batch_size: The expected batch size of the inputs.
        image_latents: Object exposing `.shape`; only `shape[0]` is inspected.
        mask: Object exposing `.shape`; only `shape[0]` is inspected.
        masked_image_latents: Object exposing `.shape`; only `shape[0]` is inspected.

    Raises:
        ValueError: If any tensor's first dimension is neither 1 nor `batch_size`. Tensors are checked in the
            order `image_latents`, `mask`, `masked_image_latents`, and the first offender is reported.
    """
    # Keep the order stable so the reported tensor is deterministic when several are wrong.
    named_tensors = (
        ("image_latents", image_latents),
        ("mask", mask),
        ("masked_image_latents", masked_image_latents),
    )
    for name, tensor in named_tensors:
        actual = tensor.shape[0]
        if actual != 1 and actual != batch_size:
            raise ValueError(f"{name} should have batch size 1 or {batch_size}, but got {actual}")
657+
657658 @torch .no_grad ()
658659 def __call__ (self , components : StableDiffusionXLModularPipeline , state : PipelineState ) -> PipelineState :
659660 block_state = self .get_block_state (state )
660-
661+
661662 self .check_inputs (
662663 batch_size = block_state .batch_size ,
663- image_latents = block_state .image_latents ,
664- mask = block_state .mask ,
665- masked_image_latents = block_state .masked_image_latents ,
666- )
664+ image_latents = block_state .image_latents ,
665+ mask = block_state .mask ,
666+ masked_image_latents = block_state .masked_image_latents ,
667+ )
667668
668669 dtype = block_state .dtype if block_state .dtype is not None else block_state .image_latents .dtype
669670 device = components ._execution_device
670-
671- final_batch_size = block_state .batch_size * block_state .num_images_per_prompt
672-
671+
672+ final_batch_size = block_state .batch_size * block_state .num_images_per_prompt
673+
673674 block_state .image_latents = block_state .image_latents .to (device = device , dtype = dtype )
674- block_state .image_latents = block_state .image_latents .repeat (final_batch_size // block_state .image_latents .shape [0 ], 1 , 1 , 1 )
675+ block_state .image_latents = block_state .image_latents .repeat (
676+ final_batch_size // block_state .image_latents .shape [0 ], 1 , 1 , 1
677+ )
675678
676679 # 7. Prepare mask latent variables
677680 block_state .mask = block_state .mask .to (device = device , dtype = dtype )
678- block_state .mask = block_state .mask .repeat (final_batch_size // block_state .mask .shape [0 ], 1 , 1 , 1 )
681+ block_state .mask = block_state .mask .repeat (final_batch_size // block_state .mask .shape [0 ], 1 , 1 , 1 )
679682
680683 block_state .masked_image_latents = block_state .masked_image_latents .to (device = device , dtype = dtype )
681- block_state .masked_image_latents = block_state .masked_image_latents .repeat (final_batch_size // block_state .masked_image_latents .shape [0 ], 1 , 1 , 1 )
682-
684+ block_state .masked_image_latents = block_state .masked_image_latents .repeat (
685+ final_batch_size // block_state .masked_image_latents .shape [0 ], 1 , 1 , 1
686+ )
687+
683688 if block_state .latent_timestep is not None :
684689 block_state .latent_timestep = block_state .latent_timestep .repeat (final_batch_size )
685690 block_state .latent_timestep = block_state .latent_timestep .to (device = device , dtype = dtype )
@@ -698,7 +703,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
698703 add_noise = add_noise ,
699704 )
700705
701-
702706 self .set_block_state (state , block_state )
703707
704708 return components , state
@@ -755,11 +759,13 @@ def intermediate_outputs(self) -> List[OutputParam]:
755759 "latents" , type_hint = torch .Tensor , description = "The initial latents to use for the denoising process"
756760 )
757761 ]
758-
762+
def check_inputs(self, batch_size, image_latents):
    """Validate the leading (batch) dimension of `image_latents`.

    Args:
        batch_size: The expected batch size of the inputs.
        image_latents: Object exposing `.shape`; only `shape[0]` is inspected.

    Raises:
        ValueError: If `image_latents.shape[0]` is neither 1 nor `batch_size`.
    """
    actual = image_latents.shape[0]
    if actual != 1 and actual != batch_size:
        raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {actual}")
768+
763769 @staticmethod
764770 def prepare_latents (image_latents , scheduler , timestep , dtype , device , generator = None ):
765771 if isinstance (generator , list ) and len (generator ) != image_latents .shape [0 ]:
@@ -788,7 +794,9 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
788794 final_batch_size = block_state .batch_size * block_state .num_images_per_prompt
789795
790796 block_state .image_latents = block_state .image_latents .to (device = device , dtype = dtype )
791- block_state .image_latents = block_state .image_latents .repeat (final_batch_size // block_state .image_latents .shape [0 ], 1 , 1 , 1 )
797+ block_state .image_latents = block_state .image_latents .repeat (
798+ final_batch_size // block_state .image_latents .shape [0 ], 1 , 1 , 1
799+ )
792800
793801 if block_state .latent_timestep is not None :
794802 block_state .latent_timestep = block_state .latent_timestep .repeat (final_batch_size )
@@ -935,7 +943,9 @@ def expected_configs(self) -> List[ConfigSpec]:
935943
@property
def expected_components(self) -> List[ComponentSpec]:
    """Return the component specs this block expects: a single `unet` of type `UNet2DConditionModel`."""
    unet_spec = ComponentSpec("unet", UNet2DConditionModel)
    return [unet_spec]
939949
940950 @property
941951 def description (self ) -> str :
@@ -976,7 +986,11 @@ def intermediate_inputs(self) -> List[InputParam]:
976986 type_hint = int ,
977987 description = "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ,
978988 ),
979- InputParam ("dtype" , type_hint = torch .dtype , description = "The dtype of the model inputs, can be generated in input step." ),
989+ InputParam (
990+ "dtype" ,
991+ type_hint = torch .dtype ,
992+ description = "The dtype of the model inputs, can be generated in input step." ,
993+ ),
980994 ]
981995
982996 @property
@@ -1052,7 +1066,7 @@ def _get_add_time_ids(
10521066 @torch .no_grad ()
10531067 def __call__ (self , components : StableDiffusionXLModularPipeline , state : PipelineState ) -> PipelineState :
10541068 block_state = self .get_block_state (state )
1055-
1069+
10561070 device = components ._execution_device
10571071 dtype = block_state .dtype if block_state .dtype is not None else block_state .pooled_prompt_embeds .dtype
10581072
@@ -1087,7 +1101,9 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
10871101 text_encoder_projection_dim = text_encoder_projection_dim ,
10881102 )
10891103 block_state .add_time_ids = block_state .add_time_ids .repeat (final_batch_size , 1 ).to (device = device )
1090- block_state .negative_add_time_ids = block_state .negative_add_time_ids .repeat (final_batch_size , 1 ).to (device = device )
1104+ block_state .negative_add_time_ids = block_state .negative_add_time_ids .repeat (final_batch_size , 1 ).to (
1105+ device = device
1106+ )
10911107
10921108 self .set_block_state (state , block_state )
10931109 return components , state
@@ -1102,7 +1118,9 @@ def description(self) -> str:
11021118
@property
def expected_components(self) -> List[ComponentSpec]:
    """Return the component specs this block expects: a single `unet` of type `UNet2DConditionModel`."""
    specs: List[ComponentSpec] = []
    specs.append(ComponentSpec("unet", UNet2DConditionModel))
    return specs
11061124
11071125 @property
11081126 def inputs (self ) -> List [Tuple [str , Any ]]:
@@ -1196,7 +1214,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
11961214 original_size = block_state .original_size or (height , width )
11971215 target_size = block_state .target_size or (height , width )
11981216
1199-
12001217 block_state .add_time_ids = self ._get_add_time_ids (
12011218 components ,
12021219 original_size ,
@@ -1218,7 +1235,9 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
12181235 block_state .negative_add_time_ids = block_state .add_time_ids
12191236
12201237 block_state .add_time_ids = block_state .add_time_ids .repeat (final_batch_size , 1 ).to (device = device )
1221- block_state .negative_add_time_ids = block_state .negative_add_time_ids .repeat (final_batch_size , 1 ).to (device = device )
1238+ block_state .negative_add_time_ids = block_state .negative_add_time_ids .repeat (final_batch_size , 1 ).to (
1239+ device = device
1240+ )
12221241
12231242 self .set_block_state (state , block_state )
12241243 return components , state
@@ -1229,7 +1248,9 @@ class StableDiffusionXLLCMStep(PipelineBlock):
12291248
@property
def expected_components(self) -> List[ComponentSpec]:
    """Return the component specs this block expects: a single `unet` of type `UNet2DConditionModel`."""
    required_unet = ComponentSpec("unet", UNet2DConditionModel)
    return [required_unet]
12331254
12341255 @property
12351256 def description (self ) -> str :
@@ -1290,30 +1311,30 @@ def get_guidance_scale_embedding(
12901311 assert emb .shape == (w .shape [0 ], embedding_dim )
12911312 return emb
12921313
1293-
def check_input(self, unet, embedded_guidance_scale):
    """Check that `embedded_guidance_scale` and `unet.config.time_cond_proj_dim` are set or unset together.

    Args:
        unet: Object whose `config.time_cond_proj_dim` attribute is read.
        embedded_guidance_scale: The guidance scale to embed, or `None`.

    Raises:
        ValueError: If exactly one of `embedded_guidance_scale` and `unet.config.time_cond_proj_dim` is `None`.
    """
    scale_given = embedded_guidance_scale is not None
    unet_has_cond_proj = unet.config.time_cond_proj_dim is not None

    if scale_given and not unet_has_cond_proj:
        raise ValueError(
            f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None"
        )

    if unet_has_cond_proj and not scale_given:
        raise ValueError("unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")
13011322
1302-
13031323 @torch .no_grad ()
13041324 def __call__ (self , components : StableDiffusionXLModularPipeline , state : PipelineState ) -> PipelineState :
13051325 block_state = self .get_block_state (state )
1306-
1326+
13071327 device = components ._execution_device
13081328 dtype = block_state .dtype if block_state .dtype is not None else components .unet .dtype
13091329
13101330 final_batch_size = block_state .batch_size * block_state .num_images_per_prompt
13111331
1312-
13131332 # Optionally get Guidance Scale Embedding for LCM
13141333 block_state .timestep_cond = None
1315-
1316- guidance_scale_tensor = torch .tensor (block_state .embedded_guidance_scale - 1 ).repeat (final_batch_size ).to (device = device )
1334+
1335+ guidance_scale_tensor = (
1336+ torch .tensor (block_state .embedded_guidance_scale - 1 ).repeat (final_batch_size ).to (device = device )
1337+ )
13171338 block_state .timestep_cond = self .get_guidance_scale_embedding (
13181339 guidance_scale_tensor , embedding_dim = components .unet .config .time_cond_proj_dim
13191340 ).to (device = device , dtype = dtype )
@@ -1476,9 +1497,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
14761497 if isinstance (controlnet , MultiControlNetModel ) and isinstance (
14771498 block_state .controlnet_conditioning_scale , float
14781499 ):
1479- block_state .conditioning_scale = [block_state .controlnet_conditioning_scale ] * len (
1480- controlnet .nets
1481- )
1500+ block_state .conditioning_scale = [block_state .controlnet_conditioning_scale ] * len (controlnet .nets )
14821501 else :
14831502 block_state .conditioning_scale = block_state .controlnet_conditioning_scale
14841503
0 commit comments