-from typing import Optional, Any, Iterable, List
 import copy
 import threading
+from typing import Any, Iterable, List, Optional
+
 import torch
+
 from diffusers.utils import logging
+
 from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
-from .wrappers import ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper, ThreadSafeImageProcessorWrapper
+from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
+
 
 logger = logging.get_logger(__name__)
 
+
 class RequestScopedPipeline:
     DEFAULT_MUTABLE_ATTRS = [
         "_all_hooks",
-        "_offload_device",
+        "_offload_device",
         "_progress_bar_config",
         "_progress_bar",
         "_rng_state",
@@ -29,42 +34,39 @@ def __init__(
         wrap_scheduler: bool = True,
     ):
         self._base = pipeline
-
-
+
         self.unet = getattr(pipeline, "unet", None)
-        self.vae = getattr(pipeline, "vae", None)
+        self.vae = getattr(pipeline, "vae", None)
         self.text_encoder = getattr(pipeline, "text_encoder", None)
         self.components = getattr(pipeline, "components", None)
-
+
         self.transformer = getattr(pipeline, "transformer", None)
-
-        if wrap_scheduler and hasattr(pipeline, 'scheduler') and pipeline.scheduler is not None:
+
+        if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
             if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                 pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
 
         self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
-
-
+
         self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
 
         self._vae_lock = threading.Lock()
         self._image_lock = threading.Lock()
-
+
         self._auto_detect_mutables = bool(auto_detect_mutables)
         self._tensor_numel_threshold = int(tensor_numel_threshold)
         self._auto_detected_attrs: List[str] = []
 
     def _detect_kernel_pipeline(self, pipeline) -> bool:
         kernel_indicators = [
-            'text_encoding_cache',
-            'memory_manager',
-            'enable_optimizations',
-            '_create_request_context',
-            'get_optimization_stats'
+            "text_encoding_cache",
+            "memory_manager",
+            "enable_optimizations",
+            "_create_request_context",
+            "get_optimization_stats",
         ]
-
-        return any(hasattr(pipeline, attr) for attr in kernel_indicators)
 
+        return any(hasattr(pipeline, attr) for attr in kernel_indicators)
 
     def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
         base_sched = getattr(self._base, "scheduler", None)
@@ -78,14 +80,12 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
 
         try:
             return wrapped_scheduler.clone_for_request(
-                num_inference_steps=num_inference_steps,
-                device=device,
-                **clone_kwargs
+                num_inference_steps=num_inference_steps, device=device, **clone_kwargs
             )
         except Exception as e:
             logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
             try:
-                if hasattr(wrapped_scheduler, 'scheduler'):
+                if hasattr(wrapped_scheduler, "scheduler"):
                     try:
                         copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
                         return BaseAsyncScheduler(copied_scheduler)
@@ -95,8 +95,10 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
                 copied_scheduler = copy.copy(wrapped_scheduler)
                 return BaseAsyncScheduler(copied_scheduler)
             except Exception as e2:
-                logger.warning(f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*).")
-                return wrapped_scheduler
+                logger.warning(
+                    f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
+                )
+                return wrapped_scheduler
 
     def _autodetect_mutables(self, max_attrs: int = 40):
         if not self._auto_detect_mutables:
@@ -107,16 +109,15 @@ def _autodetect_mutables(self, max_attrs: int = 40):
 
         candidates: List[str] = []
         seen = set()
-
-
+
         for name in dir(self._base):
             if name.startswith("__"):
                 continue
             if name in self._mutable_attrs:
                 continue
             if name in ("to", "save_pretrained", "from_pretrained"):
                 continue
-
+
             try:
                 val = getattr(self._base, name)
             except Exception:
@@ -165,7 +166,9 @@ def _clone_mutable_attrs(self, base, local):
         attrs_to_clone = list(self._mutable_attrs)
         attrs_to_clone.extend(self._autodetect_mutables())
 
-        EXCLUDE_ATTRS = {"components",}
+        EXCLUDE_ATTRS = {
+            "components",
+        }
 
         for attr in attrs_to_clone:
             if attr in EXCLUDE_ATTRS:
@@ -213,16 +216,16 @@ def _clone_mutable_attrs(self, base, local):
     def _is_tokenizer_component(self, component) -> bool:
         if component is None:
             return False
-
-        tokenizer_methods = ['encode', 'decode', 'tokenize', '__call__']
+
+        tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
         has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
-
+
         class_name = component.__class__.__name__.lower()
-        has_tokenizer_in_name = 'tokenizer' in class_name
-
-        tokenizer_attrs = ['vocab_size', 'pad_token', 'eos_token', 'bos_token']
+        has_tokenizer_in_name = "tokenizer" in class_name
+
+        tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
         has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
-
+
         return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
 
     def _should_wrap_tokenizers(self) -> bool:
@@ -238,11 +241,21 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
         local_pipe = copy.deepcopy(self._base)
 
         try:
-            if hasattr(local_pipe, "vae") and local_pipe.vae is not None and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper):
+            if (
+                hasattr(local_pipe, "vae")
+                and local_pipe.vae is not None
+                and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
+            ):
                 local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
 
-            if hasattr(local_pipe, "image_processor") and local_pipe.image_processor is not None and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper):
-                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(local_pipe.image_processor, self._image_lock)
+            if (
+                hasattr(local_pipe, "image_processor")
+                and local_pipe.image_processor is not None
+                and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
+            ):
+                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
+                    local_pipe.image_processor, self._image_lock
+                )
         except Exception as e:
             logger.debug(f"Could not wrap vae/image_processor: {e}")
 
@@ -253,7 +266,7 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
                 num_inference_steps=num_inference_steps,
                 device=device,
                 return_scheduler=True,
-                **{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
+                **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
             )
 
             final_scheduler = BaseAsyncScheduler(configured_scheduler)
@@ -262,10 +275,9 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
             logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
 
         self._clone_mutable_attrs(self._base, local_pipe)
-
 
         original_tokenizers = {}
-
+
         if self._should_wrap_tokenizers():
             try:
                 for name in dir(local_pipe):
@@ -281,7 +293,7 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
                 for key, val in local_pipe.components.items():
                     if val is None:
                         continue
-
+
                     if self._is_tokenizer_component(val):
                         if not isinstance(val, ThreadSafeTokenizerWrapper):
                             original_tokenizers[f"components[{key}]"] = val
@@ -293,9 +305,8 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
 
         result = None
         cm = getattr(local_pipe, "model_cpu_offload_context", None)
-
+
         try:
-
             if callable(cm):
                 try:
                     with cm():
@@ -316,8 +327,8 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
             try:
                 for name, tok in original_tokenizers.items():
                     if name.startswith("components["):
-                        key = name[len("components["):-1]
-                        if hasattr(local_pipe, 'components') and isinstance(local_pipe.components, dict):
+                        key = name[len("components[") : -1]
+                        if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
                             local_pipe.components[key] = tok
                         else:
                             setattr(local_pipe, name, tok)