@@ -1582,6 +1582,293 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
15821582        return  pipeline , state 
15831583
15841584
class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
    """Denoising loop for SDXL driven by a ControlNet-Union style controlnet.

    For every timestep the block runs ``pipeline.controlnet`` on the current
    latents and the prepared control images, feeds the resulting residuals to
    ``pipeline.unet``, applies guidance through ``pipeline.guider`` and steps
    the scheduler. The final latents are written back to the state under
    ``"latents"``.
    """

    expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"]
    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        """Names read with ``state.get_input``, each paired with a value (presumably its default)."""
        return [
            ("control_image", None),
            ("control_guidance_start", 0.0),
            ("control_guidance_end", 1.0),
            ("controlnet_conditioning_scale", 1.0),
            ("control_mode", 0),
            ("guess_mode", False),
            ("num_images_per_prompt", 1),
            ("guidance_scale", 5.0),
            ("guidance_rescale", 0.0),
            ("cross_attention_kwargs", None),
            ("generator", None),
            ("eta", 0.0),
            ("guider_kwargs", None),
        ]

    @property
    def intermediates_inputs(self) -> List[str]:
        """Names read with ``state.get_intermediate`` by ``__call__``."""
        return [
            "latents",
            "batch_size",
            "timesteps",
            "num_inference_steps",
            "prompt_embeds",
            "negative_prompt_embeds",
            "add_time_ids",
            "negative_add_time_ids",
            "pooled_prompt_embeds",
            "negative_pooled_prompt_embeds",
            "timestep_cond",
            "mask",
            "noise",
            "image_latents",
            "crops_coords",
        ]

    @property
    def intermediates_outputs(self) -> List[str]:
        """Names written with ``state.add_intermediate`` by ``__call__``."""
        return ["latents"]

    def __init__(self):
        """Register default guiders, empty model slots and the control image processor."""
        super().__init__()
        self.components["guider"] = CFGGuider()
        self.components["controlnet_guider"] = CFGGuider()
        # Model components start out unset in this block.
        self.components["scheduler"] = None
        self.components["unet"] = None
        self.components["controlnet"] = None
        control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False)
        self.auxiliaries["control_image_processor"] = control_image_processor

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        """Run the controlnet-conditioned denoising loop.

        Args:
            pipeline: Object exposing ``unet``, ``controlnet``, ``scheduler``,
                ``guider``, ``controlnet_guider`` and the helper methods called below.
            state: Pipeline state holding the inputs and intermediates listed by
                ``inputs`` and ``intermediates_inputs``.

        Returns:
            The ``(pipeline, state)`` pair, with the denoised ``"latents"`` stored
            on ``state``.

        Raises:
            ValueError: If the number of control images differs from the number
                of control modes.
        """
        guidance_scale = state.get_input("guidance_scale")
        guidance_rescale = state.get_input("guidance_rescale")
        cross_attention_kwargs = state.get_input("cross_attention_kwargs")
        guider_kwargs = state.get_input("guider_kwargs")
        generator = state.get_input("generator")
        eta = state.get_input("eta")
        num_images_per_prompt = state.get_input("num_images_per_prompt")
        # controlnet-specific inputs
        control_image = state.get_input("control_image")
        control_guidance_start = state.get_input("control_guidance_start")
        control_guidance_end = state.get_input("control_guidance_end")
        controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale")
        control_mode = state.get_input("control_mode")
        guess_mode = state.get_input("guess_mode")

        batch_size = state.get_intermediate("batch_size")
        latents = state.get_intermediate("latents")
        timesteps = state.get_intermediate("timesteps")
        num_inference_steps = state.get_intermediate("num_inference_steps")

        prompt_embeds = state.get_intermediate("prompt_embeds")
        negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds")
        pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds")
        negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds")
        add_time_ids = state.get_intermediate("add_time_ids")
        negative_add_time_ids = state.get_intermediate("negative_add_time_ids")

        timestep_cond = state.get_intermediate("timestep_cond")

        # inpainting
        mask = state.get_intermediate("mask")
        noise = state.get_intermediate("noise")
        image_latents = state.get_intermediate("image_latents")
        crops_coords = state.get_intermediate("crops_coords")

        device = pipeline._execution_device

        # Target size for the control images, derived from the latent size.
        height, width = latents.shape[-2:]
        height = height * pipeline.vae_scale_factor
        width = width * pipeline.vae_scale_factor

        # prepare controlnet inputs
        # Unwrap a compiled module so that `config` and `dtype` are read from the original model.
        controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]

        global_pool_conditions = controlnet.config.global_pool_conditions
        guess_mode = guess_mode or global_pool_conditions

        num_control_type = controlnet.config.num_control_type

        if isinstance(control_image, list):
            # Work on a copy: entries are overwritten with prepared tensors below, and the
            # caller-supplied list held by `state` must not be modified.
            control_image = list(control_image)
        else:
            control_image = [control_image]

        if not isinstance(control_mode, list):
            control_mode = [control_mode]

        if len(control_image) != len(control_mode):
            raise ValueError("Expected len(control_image) == len(control_mode)")

        # Indicator vector with a 1 at the index of every requested control mode.
        control_type_flags = [0] * num_control_type
        for control_idx in control_mode:
            control_type_flags[control_idx] = 1
        control_type = torch.Tensor(control_type_flags)

        for idx, image in enumerate(control_image):
            control_image[idx] = pipeline.prepare_control_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                crops_coords=crops_coords,
            )
            # Later images are prepared at the size of the previously prepared one.
            height, width = control_image[idx].shape[-2:]

        # Per-step factor: 1.0 while the step lies inside the control guidance window, else 0.0.
        # NOTE(review): after the alignment above, `control_guidance_start` / `control_guidance_end`
        # may be lists, but this comparison treats them as scalars, so list values raise a
        # TypeError here. Confirm whether list values are meant to be supported.
        num_timesteps = len(timesteps)
        controlnet_keep = [
            1.0 - float(i / num_timesteps < control_guidance_start or (i + 1) / num_timesteps > control_guidance_end)
            for i in range(num_timesteps)
        ]

        # Only the first entry of a list-valued conditioning scale is used.
        controlnet_cond_scale = controlnet_conditioning_scale
        if isinstance(controlnet_cond_scale, list):
            controlnet_cond_scale = controlnet_cond_scale[0]

        # Prepare conditional inputs for unet using the guider
        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
        disable_guidance = pipeline.unet.config.time_cond_proj_dim is not None
        guider_kwargs = {
            **(guider_kwargs or {}),
            "disable_guidance": disable_guidance,
            "guidance_scale": guidance_scale,
            "guidance_rescale": guidance_rescale,
            "batch_size": batch_size,
        }
        pipeline.guider.set_guider(pipeline, guider_kwargs)
        prompt_embeds = pipeline.guider.prepare_input(
            prompt_embeds,
            negative_prompt_embeds,
        )
        add_time_ids = pipeline.guider.prepare_input(
            add_time_ids,
            negative_add_time_ids,
        )
        pooled_prompt_embeds = pipeline.guider.prepare_input(
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        added_cond_kwargs = {
            "text_embeds": pooled_prompt_embeds,
            "time_ids": add_time_ids,
        }

        # Prepare conditional inputs for controlnet using the guider
        # Controlnet guidance is additionally disabled in guess mode.
        controlnet_disable_guidance = bool(disable_guidance or guess_mode)
        controlnet_guider_kwargs = {
            **guider_kwargs,
            "disable_guidance": controlnet_disable_guidance,
            "guidance_scale": guidance_scale,
            "guidance_rescale": guidance_rescale,
            "batch_size": batch_size,
        }
        pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs)
        controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds)
        controlnet_added_cond_kwargs = {
            "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds),
            "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids),
        }
        control_image = [pipeline.controlnet_guider.prepare_input(image, image) for image in control_image]

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)

        # NOTE(review): the factor 2 presumably matches a doubled (conditional + unconditional)
        # controlnet batch; confirm it is still correct when controlnet guidance is disabled.
        control_type = (
            control_type.reshape(1, -1)
            .to(device, dtype=prompt_embeds.dtype)
            .repeat(batch_size * num_images_per_prompt * 2, 1)
        )
        with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # prepare latents for unet using the guider
                latent_model_input = pipeline.guider.prepare_input(latents, latents)

                # prepare latents for controlnet using the guider
                control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents)

                # `controlnet_keep[i]` is always a float, so the scale is a plain product.
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
                    pipeline.scheduler.scale_model_input(control_model_input, t),
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    control_type=control_type,
                    control_type_idx=control_mode,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                # when we apply guidance for unet, but not for controlnet:
                # add 0 to the unconditional batch
                down_block_res_samples = pipeline.guider.prepare_input(
                    down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples]
                )
                mid_block_res_sample = pipeline.guider.prepare_input(
                    mid_block_res_sample, torch.zeros_like(mid_block_res_sample)
                )

                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]
                # perform guidance
                noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if mask is not None and image_latents is not None:
                    # Inpainting: blend the (re-noised) original image latents with the new
                    # latents using the mask.
                    init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0]
                    init_latents_proper = image_latents
                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = pipeline.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor([noise_timestep])
                        )

                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        pipeline.controlnet_guider.reset_guider(pipeline)
        state.add_intermediate("latents", latents)

        return pipeline, state
1871+ 
15851872class  StableDiffusionXLDecodeLatentsStep (PipelineBlock ):
15861873    expected_components  =  ["vae" ]
15871874    model_name  =  "stable-diffusion-xl" 
0 commit comments