@@ -71,9 +71,10 @@ def retrieve_timesteps(
7171    num_inference_steps : Optional [int ] =  None ,
7272    device : Optional [Union [str , torch .device ]] =  None ,
7373    timesteps : Optional [List [int ]] =  None ,
74+     sigmas : Optional [List [float ]] =  None ,
7475    ** kwargs ,
7576):
76-     """ 
77+     r """
7778    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 
7879    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
7980
@@ -86,14 +87,18 @@ def retrieve_timesteps(
8687        device (`str` or `torch.device`, *optional*): 
8788            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
8889        timesteps (`List[int]`, *optional*): 
89-                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default 
90-                 timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` 
91-                 must be `None`. 
90+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 
91+             `num_inference_steps` and `sigmas` must be `None`. 
92+         sigmas (`List[float]`, *optional*): 
93+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 
94+             `num_inference_steps` and `timesteps` must be `None`. 
9295
9396    Returns: 
9497        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 
9598        second element is the number of inference steps. 
9699    """ 
100+     if  timesteps  is  not None  and  sigmas  is  not None :
101+         raise  ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
97102    if  timesteps  is  not None :
98103        accepts_timesteps  =  "timesteps"  in  set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
99104        if  not  accepts_timesteps :
@@ -104,6 +109,16 @@ def retrieve_timesteps(
104109        scheduler .set_timesteps (timesteps = timesteps , device = device , ** kwargs )
105110        timesteps  =  scheduler .timesteps 
106111        num_inference_steps  =  len (timesteps )
112+     elif  sigmas  is  not None :
113+         accept_sigmas  =  "sigmas"  in  set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
114+         if  not  accept_sigmas :
115+             raise  ValueError (
116+                 f"The current scheduler class { scheduler .__class__ }  
117+                 f" sigmas schedules. Please check whether you are using the correct scheduler." 
118+             )
119+         scheduler .set_timesteps (sigmas = sigmas , device = device , ** kwargs )
120+         timesteps  =  scheduler .timesteps 
121+         num_inference_steps  =  len (timesteps )
107122    else :
108123        scheduler .set_timesteps (num_inference_steps , device = device , ** kwargs )
109124        timesteps  =  scheduler .timesteps 
@@ -458,14 +473,12 @@ def _clean_caption(self, caption):
458473        caption  =  re .sub ("<person>" , "person" , caption )
459474        # urls: 
460475        caption  =  re .sub (
461-             r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))" ,
462-             # noqa 
476+             r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))" ,  # noqa 
463477            "" ,
464478            caption ,
465479        )  # regex for urls 
466480        caption  =  re .sub (
467-             r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))" ,
468-             # noqa 
481+             r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))" ,  # noqa 
469482            "" ,
470483            caption ,
471484        )  # regex for urls 
@@ -488,13 +501,12 @@ def _clean_caption(self, caption):
488501        caption  =  re .sub (r"[\u3300-\u33ff]+" , "" , caption )
489502        caption  =  re .sub (r"[\u3400-\u4dbf]+" , "" , caption )
490503        caption  =  re .sub (r"[\u4dc0-\u4dff]+" , "" , caption )
491-         #  caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
504+         caption  =  re .sub (r"[\u4e00-\u9fff]+" , "" , caption )
492505        ####################################################### 
493506
494507        # все виды тире / all types of dash --> "-" 
495508        caption  =  re .sub (
496-             r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+" ,
497-             # noqa 
509+             r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+" ,  # noqa 
498510            "-" ,
499511            caption ,
500512        )
@@ -565,6 +577,7 @@ def _clean_caption(self, caption):
565577        caption  =  re .sub (r"^[\'\_,\-\:;]" , r"" , caption )
566578        caption  =  re .sub (r"[\'\_,\-\:\-\+]$" , r"" , caption )
567579        caption  =  re .sub (r"^\.\S+$" , "" , caption )
580+ 
568581        return  caption .strip ()
569582
570583    def  prepare_latents (
0 commit comments