@@ -3,6 +3,8 @@
 import threading
 import torch
 from diffusers.utils import logging
+from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+

 logger = logging.get_logger(__name__)

@@ -27,14 +29,19 @@ def __init__(
         mutable_attrs: Optional[Iterable[str]] = None,
         auto_detect_mutables: bool = True,
         tensor_numel_threshold: int = 1_000_000,
-        tokenizer_lock: Optional[threading.Lock] = None
+        tokenizer_lock: Optional[threading.Lock] = None,
+        wrap_scheduler: bool = True
     ):
         self._base = pipeline
         self.unet = getattr(pipeline, "unet", None)
         self.vae = getattr(pipeline, "vae", None)
         self.text_encoder = getattr(pipeline, "text_encoder", None)
         self.components = getattr(pipeline, "components", None)

+        if wrap_scheduler and hasattr(pipeline, 'scheduler') and pipeline.scheduler is not None:
+            if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
+                pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
+
         self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
         self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

@@ -48,17 +55,24 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
         if base_sched is None:
             return None

-        if hasattr(base_sched, "clone_for_request"):
-            try:
-                return base_sched.clone_for_request(num_inference_steps=num_inference_steps, device=device, **clone_kwargs)
-            except Exception as e:
-                logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+        if not isinstance(base_sched, BaseAsyncScheduler):
+            wrapped_scheduler = BaseAsyncScheduler(base_sched)
+        else:
+            wrapped_scheduler = base_sched

         try:
-            return copy.deepcopy(base_sched)
+            return wrapped_scheduler.clone_for_request(
+                num_inference_steps=num_inference_steps,
+                device=device,
+                **clone_kwargs
+            )
         except Exception as e:
-            logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
-            return base_sched
+            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+            try:
+                return copy.deepcopy(wrapped_scheduler)
+            except Exception as e:
+                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+                return wrapped_scheduler

     def _autodetect_mutables(self, max_attrs: int = 40):
         if not self._auto_detect_mutables:
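Note on the hunk above: the replacement flow always normalizes the base scheduler into a BaseAsyncScheduler, prefers its clone_for_request() for per-request isolation, and only falls back to copy.deepcopy() if cloning raises. A minimal standalone sketch of that pattern follows; the DDIMScheduler instance and the "cpu" device are purely illustrative, and only the clone_for_request keyword signature is taken from this diff.

    import copy

    from diffusers import DDIMScheduler

    from .scheduler import BaseAsyncScheduler  # same helper imported at the top of this module

    # Any diffusers scheduler works here; DDIMScheduler is just an example.
    base_sched = DDIMScheduler(num_train_timesteps=1000)

    # Normalize to the async wrapper exactly once (idempotent).
    wrapped = base_sched if isinstance(base_sched, BaseAsyncScheduler) else BaseAsyncScheduler(base_sched)

    try:
        # Preferred path: a clone configured for this request's step count and device.
        local_sched = wrapped.clone_for_request(num_inference_steps=30, device="cpu")
    except Exception:
        # Fallback path: an independent deep copy still avoids sharing mutable
        # scheduler state (step counters, cached sigmas) across concurrent requests.
        local_sched = copy.deepcopy(wrapped)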
@@ -197,7 +211,16 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =

         if local_scheduler is not None:
             try:
-                setattr(local_pipe, "scheduler", local_scheduler)
+                timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
+                    local_scheduler.scheduler,
+                    num_inference_steps=num_inference_steps,
+                    device=device,
+                    return_scheduler=True,
+                    **{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
+                )
+
+                final_scheduler = BaseAsyncScheduler(configured_scheduler)
+                setattr(local_pipe, "scheduler", final_scheduler)
             except Exception:
                 logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")

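For the generate() change above, a hedged sketch of the new timestep path: async_retrieve_timesteps is assumed to mirror diffusers' retrieve_timesteps helper, with return_scheduler=True additionally handing back the scheduler it configured so it can be re-wrapped and attached to the per-request pipeline. The EulerDiscreteScheduler and the step count below are illustrative; the argument names and the three-tuple return shape are taken from this diff.

    from diffusers import EulerDiscreteScheduler

    from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps  # helpers added by this change

    sched = EulerDiscreteScheduler(num_train_timesteps=1000)

    # Configure the timestep schedule once, up front, and get back the scheduler
    # instance that now owns it (return_scheduler=True).
    timesteps, num_steps, configured = async_retrieve_timesteps(
        sched,
        num_inference_steps=30,
        device="cpu",
        return_scheduler=True,
    )

    # Re-wrap before attaching to the per-request pipeline, mirroring the
    # final_scheduler assignment in the hunk above.
    final_scheduler = BaseAsyncScheduler(configured)
    print(num_steps, len(timesteps), type(final_scheduler).__name__)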