1717
1818import  PIL 
1919import  torch 
20+ from  collections  import  OrderedDict 
2021
2122from  ...guider  import  CFGGuider 
2223from  ...image_processor  import  VaeImageProcessor 
@@ -122,64 +123,6 @@ def retrieve_latents(
122123        raise  AttributeError ("Could not access latents of provided encoder_output" )
123124
124125
class StableDiffusionXLOutputStep(PipelineBlock):
    """Terminal block: packages decoded images into the pipeline's return value.

    Honors the standard diffusers `return_dict` convention: a
    `StableDiffusionXLPipelineOutput` by default, a plain tuple otherwise.
    """

    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        # Callers opt out of the dataclass wrapper with return_dict=False.
        return [("return_dict", True)]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["images"]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        as_dict = state.get_input("return_dict")
        decoded = state.get_intermediate("images")

        output = StableDiffusionXLPipelineOutput(images=decoded) if as_dict else (decoded,)
        state.add_output("images", output)
        return pipeline, state
148- 
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
    """Pastes the generated region back into the original image after inpainting.

    Only active when `padding_mask_crop` was used upstream (and crop
    coordinates were recorded); otherwise images pass through untouched.
    """

    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            ("image", None),
            ("mask_image", None),
            ("padding_mask_crop", None),
        ]

    @property
    def intermediates_inputs(self) -> List[str]:
        return ["crops_coords", "images"]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["images"]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        source_image = state.get_input("image")
        crop_setting = state.get_input("padding_mask_crop")
        mask = state.get_input("mask_image")
        generated = state.get_intermediate("images")
        coords = state.get_intermediate("crops_coords")

        # Overlay only when cropping was requested AND coordinates exist.
        if crop_setting is not None and coords is not None:
            generated = [
                pipeline.image_processor.apply_overlay(mask, source_image, img, coords)
                for img in generated
            ]

        state.add_intermediate("images", generated)

        return pipeline, state
182- 
183126
184127class  StableDiffusionXLInputStep (PipelineBlock ):
185128    model_name  =  "stable-diffusion-xl" 
@@ -376,7 +319,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
376319        return  pipeline , state 
377320
378321
379- class  StableDiffusionXLVAEEncoderStep (PipelineBlock ):
322+ class  StableDiffusionXLVaeEncoderStep (PipelineBlock ):
380323    expected_components  =  ["vae" ]
381324    model_name  =  "stable-diffusion-xl" 
382325
@@ -589,7 +532,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
589532        return  pipeline , state 
590533
591534
592- class  StableDiffusionXLInpaintVaeEncodeStep (PipelineBlock ):
535+ class  StableDiffusionXLInpaintVaeEncoderStep (PipelineBlock ):
593536    expected_components  =  ["vae" ]
594537    model_name  =  "stable-diffusion-xl" 
595538
@@ -694,7 +637,6 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin
694637        return  pipeline , state 
695638
696639
697- # inpaint-specific 
698640class  StableDiffusionXLInpaintPrepareLatentsStep (PipelineBlock ):
699641    expected_components  =  ["scheduler" ]
700642    model_name  =  "stable-diffusion-xl" 
@@ -804,27 +746,22 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
804746    @property  
805747    def  inputs (self ) ->  List [Tuple [str , Any ]]:
806748        return  [
807-             ("height" , None ),
808-             ("width" , None ),
809749            ("generator" , None ),
810750            ("latents" , None ),
811751            ("num_images_per_prompt" , 1 ),
812-             ("image" , None ),
813752            ("denoising_start" , None ),
814753        ]
815754
816755    @property  
817756    def  intermediates_inputs (self ) ->  List [str ]:
818-         return  ["batch_size" , "dtype" , "latent_timestep" ]
757+         return  ["batch_size" , "dtype" , "latent_timestep" ,  "image_latents" ]
819758
820759    @property  
821760    def  intermediates_outputs (self ) ->  List [str ]:
822761        return  ["latents" ]
823762
824763    def  __init__ (self ):
825764        super ().__init__ ()
826-         self .auxiliaries ["image_processor" ] =  VaeImageProcessor ()
827-         self .components ["vae" ] =  None 
828765        self .components ["scheduler" ] =  None 
829766
830767    @torch .no_grad () 
@@ -834,24 +771,22 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin
834771        generator  =  state .get_input ("generator" )
835772
836773        # image to image only 
837-         image  =  state .get_input ("image" )
838774        denoising_start  =  state .get_input ("denoising_start" )
839775
840776        batch_size  =  state .get_intermediate ("batch_size" )
841777        dtype  =  state .get_intermediate ("dtype" )
842778        # image to image only 
843779        latent_timestep  =  state .get_intermediate ("latent_timestep" )
780+         image_latents  =  state .get_intermediate ("image_latents" )
844781
845782        if  dtype  is  None :
846783            dtype  =  pipeline .vae .dtype 
847784
848785        device  =  pipeline ._execution_device 
849- 
850-         image  =  pipeline .image_processor .preprocess (image )
851786        add_noise  =  True  if  denoising_start  is  None  else  False 
852787        if  latents  is  None :
853788            latents  =  pipeline .prepare_latents_img2img (
854-                 image ,
789+                 image_latents ,
855790                latent_timestep ,
856791                batch_size ,
857792                num_images_per_prompt ,
@@ -1723,6 +1658,81 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
17231658        return  pipeline , state 
17241659
17251660
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
    """Re-applies the unmasked portion of the original image over generated output.

    No-op unless both `padding_mask_crop` and the recorded crop coordinates
    are present in the pipeline state.
    """

    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            ("image", None),
            ("mask_image", None),
            ("padding_mask_crop", None),
        ]

    @property
    def intermediates_inputs(self) -> List[str]:
        return ["crops_coords", "images"]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["images"]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        original = state.get_input("image")
        mask = state.get_input("mask_image")
        padding_crop = state.get_input("padding_mask_crop")
        results = state.get_intermediate("images")
        crop_coords = state.get_intermediate("crops_coords")

        overlay_needed = padding_crop is not None and crop_coords is not None
        if overlay_needed:
            results = [
                pipeline.image_processor.apply_overlay(mask, original, item, crop_coords)
                for item in results
            ]

        state.add_intermediate("images", results)

        return pipeline, state
1694+ 
1695+ 
class StableDiffusionXLOutputStep(PipelineBlock):
    """Wraps the final images as the pipeline's public output.

    Returns a `StableDiffusionXLPipelineOutput` when `return_dict` is truthy
    (the default), or a one-element tuple when callers request raw output.
    """

    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [("return_dict", True)]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["images"]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        final_images = state.get_intermediate("images")
        wrap = state.get_input("return_dict")

        if wrap:
            result = StableDiffusionXLPipelineOutput(images=final_images)
        else:
            result = (final_images,)
        state.add_output("images", result)
        return pipeline, state
1718+ 
1719+ 
class StableDiffusionXLDecodeStep(SequentialPipelineBlocks):
    """Decode latents into images, then package the pipeline output (non-inpaint path)."""

    block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep]
    block_names = ["decode", "output"]
1723+ 
1724+ 
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
    """Decode latents, re-apply the inpaint mask overlay, then package the output."""

    block_classes = [
        StableDiffusionXLDecodeLatentsStep,
        StableDiffusionXLInpaintOverlayMaskStep,
        StableDiffusionXLOutputStep,
    ]
    block_names = ["decode", "mask_overlay", "output"]
1728+ 
1729+ 
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
    """Picks the inpaint VAE encoder when `mask_image` is supplied, otherwise the img2img encoder."""

    block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
    block_names = ["inpaint", "img2img"]
    # Trigger order matters: `mask_image` takes precedence over plain `image`.
    block_trigger_inputs = ["mask_image", "image"]
1734+ 
1735+ 
17261736class  StableDiffusionXLAutoSetTimestepsStep (AutoPipelineBlocks ):
17271737    block_classes  =  [StableDiffusionXLImg2ImgSetTimestepsStep , StableDiffusionXLSetTimestepsStep ]
17281738    block_names  =  ["img2img" , "text2img" ]
@@ -1750,38 +1760,67 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
17501760    block_trigger_inputs  =  ["control_image" , None ]
17511761
17521762
class StableDiffusionXLDecodeStep(SequentialPipelineBlocks):
    """Decode latents into images, then package the pipeline output."""

    block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep]
    block_names = ["decode", "output"]


class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
    """Decode latents, overlay the original image outside the mask, then package the output."""

    block_classes = [
        StableDiffusionXLDecodeLatentsStep,
        StableDiffusionXLInpaintOverlayMaskStep,
        StableDiffusionXLOutputStep,
    ]
    block_names = ["decode", "mask_overlay", "output"]
1760- 
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
    """Dispatches to the inpaint decode path when `padding_mask_crop` is given, else plain decode."""

    block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
    block_names = ["inpaint", "non-inpaint"]
    # `None` marks the unconditional fallback branch.
    block_trigger_inputs = ["padding_mask_crop", None]
17651767
class StableDiffusionXLAllSteps(SequentialPipelineBlocks):
    """Complete SDXL workflow assembled from the auto-dispatching step blocks.

    Each auto step selects its concrete implementation (text2img / img2img /
    inpaint / controlnet) from the inputs present at call time.
    """

    block_classes = [
        StableDiffusionXLInputStep,
        StableDiffusionXLTextEncoderStep,
        StableDiffusionXLAutoSetTimestepsStep,
        StableDiffusionXLAutoPrepareLatentsStep,
        StableDiffusionXLAutoPrepareAdditionalConditioningStep,
        StableDiffusionXLAutoDenoiseStep,
        StableDiffusionXLAutoDecodeStep,
    ]
    block_names = [
        "input",
        "text_encoder",
        "set_timesteps",
        "prepare_latents",
        "prepare_add_cond",
        "denoise",
        "decode",
    ]
1768+ 
# Canonical block orderings for each SDXL modular workflow. Each preset maps
# a step name to the PipelineBlock class implementing it; insertion order is
# the execution order, hence OrderedDict.

TEXT2IMAGE_BLOCKS = OrderedDict([
    ("input", StableDiffusionXLInputStep),
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLAutoDenoiseStep),
    ("decode", StableDiffusionXLDecodeStep),
])

IMAGE2IMAGE_BLOCKS = OrderedDict([
    ("input", StableDiffusionXLInputStep),
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("image_encoder", StableDiffusionXLVaeEncoderStep),
    ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLDenoiseStep),
    ("decode", StableDiffusionXLDecodeStep),
])

INPAINT_BLOCKS = OrderedDict([
    ("input", StableDiffusionXLInputStep),
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
    ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLDenoiseStep),
    ("decode", StableDiffusionXLInpaintDecodeStep),
])

# ControlNet only swaps in a different denoise step.
CONTROLNET_BLOCKS = OrderedDict([
    ("denoise", StableDiffusionXLControlNetDenoiseStep),
])

# Auto-dispatching variant: each auto step chooses its concrete block from
# the inputs supplied at call time.
AUTO_BLOCKS = OrderedDict([
    ("input", StableDiffusionXLInputStep),
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
    ("set_timesteps", StableDiffusionXLAutoSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLAutoPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLAutoPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLAutoDenoiseStep),
    ("decode", StableDiffusionXLAutoDecodeStep),
])


# Registry of every supported block preset, keyed by workflow name.
SDXL_SUPPORTED_BLOCKS = {
    "text2img": TEXT2IMAGE_BLOCKS,
    "img2img": IMAGE2IMAGE_BLOCKS,
    "inpaint": INPAINT_BLOCKS,
    "controlnet": CONTROLNET_BLOCKS,
    "auto": AUTO_BLOCKS,
}
17851824
17861825
17871826class  StableDiffusionXLModularPipeline (
0 commit comments