+from typing import Optional, Any, Iterable, List
 import copy
 import threading
-from typing import Any, Iterable, List, Optional
-
 import torch
-
 from diffusers.utils import logging
-
 from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps

-
 logger = logging.get_logger(__name__)

+class ThreadSafeTokenizerWrapper:
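+    """Delegating proxy that serializes tokenizer calls behind a shared lock."""
+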
+    def __init__(self, tokenizer, lock):
+        self._tokenizer = tokenizer
+        self._lock = lock

-def safe_tokenize(tokenizer, *args, lock, **kwargs):
-    with lock:
-        return tokenizer(*args, **kwargs)
-
+        self._thread_safe_methods = {
+            '__call__', 'encode', 'decode', 'tokenize',
+            'encode_plus', 'batch_encode_plus', 'batch_decode'
+        }
+
+    def __getattr__(self, name):
+        attr = getattr(self._tokenizer, name)
+
+        if name in self._thread_safe_methods and callable(attr):
+            def wrapped_method(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+            return wrapped_method
+
+        return attr
+
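+    # Special-method lookup bypasses __getattr__ on new-style classes, so
+    # __call__ is forwarded explicitly to keep direct tokenizer(...) calls
+    # behind the lock.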
+    def __call__(self, *args, **kwargs):
+        with self._lock:
+            return self._tokenizer(*args, **kwargs)
+
+    def __setattr__(self, name, value):
+        if name.startswith('_'):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._tokenizer, name, value)
+
+    def __dir__(self):
+        return dir(self._tokenizer)
+
+
+class ThreadSafeVAEWrapper:
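+    """Proxy that serializes VAE encode/decode/forward behind a shared lock."""
+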
+    def __init__(self, vae, lock):
+        self._vae = vae
+        self._lock = lock
+
+    def __getattr__(self, name):
+        attr = getattr(self._vae, name)
+        # methods we want to protect
+        if name in {"decode", "encode", "forward"} and callable(attr):
+            def wrapped(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+            return wrapped
+        return attr
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._vae, name, value)
+
+
+class ThreadSafeImageProcessorWrapper:
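+    """Proxy that serializes image-processor pre/postprocessing behind a shared lock."""
+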
+    def __init__(self, proc, lock):
+        self._proc = proc
+        self._lock = lock
+
+    def __getattr__(self, name):
+        attr = getattr(self._proc, name)
+        if name in {"postprocess", "preprocess"} and callable(attr):
+            def wrapped(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+            return wrapped
+        return attr
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._proc, name, value)

 class RequestScopedPipeline:
     DEFAULT_MUTABLE_ATTRS = [
         "_all_hooks",
         "_offload_device",
         "_progress_bar_config",
         "_progress_bar",
         "_rng_state",
@@ -38,23 +104,43 @@ def __init__(
         wrap_scheduler: bool = True,
     ):
         self._base = pipeline
+
         self.unet = getattr(pipeline, "unet", None)
         self.vae = getattr(pipeline, "vae", None)
         self.text_encoder = getattr(pipeline, "text_encoder", None)
         self.components = getattr(pipeline, "components", None)

+        self.transformer = getattr(pipeline, "transformer", None)
+
-        if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
+        if wrap_scheduler and hasattr(pipeline, 'scheduler') and pipeline.scheduler is not None:
             if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                 pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)

         self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+
         self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

+        self._vae_lock = threading.Lock()
+        self._image_lock = threading.Lock()
+
         self._auto_detect_mutables = bool(auto_detect_mutables)
         self._tensor_numel_threshold = int(tensor_numel_threshold)
-
         self._auto_detected_attrs: List[str] = []

+    def _detect_kernel_pipeline(self, pipeline) -> bool:
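+        """Heuristic check for pipelines exposing caching/optimization hooks."""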
+        kernel_indicators = [
+            'text_encoding_cache',
+            'memory_manager',
+            'enable_optimizations',
+            '_create_request_context',
+            'get_optimization_stats'
+        ]
+
+        return any(hasattr(pipeline, attr) for attr in kernel_indicators)
+
     def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
         base_sched = getattr(self._base, "scheduler", None)
         if base_sched is None:
@@ -67,15 +153,25 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]

         try:
             return wrapped_scheduler.clone_for_request(
-                num_inference_steps=num_inference_steps, device=device, **clone_kwargs
+                num_inference_steps=num_inference_steps,
+                device=device,
+                **clone_kwargs
             )
         except Exception as e:
-            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+            logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
             try:
-                return copy.deepcopy(wrapped_scheduler)
-            except Exception as e:
-                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
-                return wrapped_scheduler
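+                # copy.copy yields a distinct scheduler object (config and
+                # tensors still shared) that is re-wrapped per request,
+                # avoiding the heavier deepcopy this path used before.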
+                if hasattr(wrapped_scheduler, 'scheduler'):
+                    try:
+                        copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
+                        return BaseAsyncScheduler(copied_scheduler)
+                    except Exception:
+                        return wrapped_scheduler
+                else:
+                    copied_scheduler = copy.copy(wrapped_scheduler)
+                    return BaseAsyncScheduler(copied_scheduler)
+            except Exception as e2:
+                logger.warning(f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*).")
+                return wrapped_scheduler

     def _autodetect_mutables(self, max_attrs: int = 40):
         if not self._auto_detect_mutables:
@@ -86,25 +182,26 @@ def _autodetect_mutables(self, max_attrs: int = 40):

         candidates: List[str] = []
         seen = set()
+
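+        # Heuristic scan: walk public, non-callable attributes and collect
+        # plain containers as per-request mutable state.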
         for name in dir(self._base):
             if name.startswith("__"):
                 continue
             if name in self._mutable_attrs:
                 continue
             if name in ("to", "save_pretrained", "from_pretrained"):
                 continue
+
             try:
                 val = getattr(self._base, name)
             except Exception:
                 continue

             import types

-            # skip callables and modules
             if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                 continue

-            # containers -> candidate
             if isinstance(val, (dict, list, set, tuple, bytearray)):
                 candidates.append(name)
                 seen.add(name)
@@ -143,9 +240,7 @@ def _clone_mutable_attrs(self, base, local):
         attrs_to_clone = list(self._mutable_attrs)
         attrs_to_clone.extend(self._autodetect_mutables())

-        EXCLUDE_ATTRS = {
-            "components",
-        }
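+        # "components" stays shared: it maps names to the heavy sub-modules
+        # (UNet, VAE, text encoder) that every request is meant to reuse.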
+        EXCLUDE_ATTRS = {"components"}

         for attr in attrs_to_clone:
             if attr in EXCLUDE_ATTRS:
@@ -193,18 +288,21 @@ def _clone_mutable_attrs(self, base, local):
     def _is_tokenizer_component(self, component) -> bool:
         if component is None:
             return False

-        tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
+        tokenizer_methods = ['encode', 'decode', 'tokenize', '__call__']
         has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)

         class_name = component.__class__.__name__.lower()
-        has_tokenizer_in_name = "tokenizer" in class_name
+        has_tokenizer_in_name = 'tokenizer' in class_name

-        tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
+        tokenizer_attrs = ['vocab_size', 'pad_token', 'eos_token', 'bos_token']
         has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)

         return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)

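+    # Always wrap for now; kept as a separate method so the policy can be
+    # changed (or made configurable) in one place.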
+    def _should_wrap_tokenizers(self) -> bool:
+        return True
+
     def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
         local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)

@@ -214,14 +312,23 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
             logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
             local_pipe = copy.deepcopy(self._base)

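+        # The shallow-copied pipe still references the original VAE and image
+        # processor; route their hot methods through the shared locks.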
+        try:
+            if hasattr(local_pipe, "vae") and local_pipe.vae is not None and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper):
+                local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
+
+            if hasattr(local_pipe, "image_processor") and local_pipe.image_processor is not None and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper):
+                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(local_pipe.image_processor, self._image_lock)
+        except Exception as e:
+            logger.debug(f"Could not wrap vae/image_processor: {e}")
+
         if local_scheduler is not None:
             try:
                 timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
                     local_scheduler.scheduler,
                     num_inference_steps=num_inference_steps,
                     device=device,
                     return_scheduler=True,
-                    **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
+                    **{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
                 )

                 final_scheduler = BaseAsyncScheduler(configured_scheduler)
@@ -230,67 +337,64 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
                 logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")

         self._clone_mutable_attrs(self._base, local_pipe)

-        # 4) wrap tokenizers on the local pipe with the lock wrapper
-        tokenizer_wrappers = {}  # name -> original_tokenizer
-        try:
-            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
-            for name in dir(local_pipe):
-                if "tokenizer" in name and not name.startswith("_"):
-                    tok = getattr(local_pipe, name, None)
-                    if tok is not None and self._is_tokenizer_component(tok):
-                        tokenizer_wrappers[name] = tok
-                        setattr(
-                            local_pipe,
-                            name,
-                            lambda *args, tok=tok, **kwargs: safe_tokenize(
-                                tok, *args, lock=self._tokenizer_lock, **kwargs
-                            ),
-                        )
-
-            # b) wrap tokenizers in components dict
-            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
-                for key, val in local_pipe.components.items():
-                    if val is None:
-                        continue
-
-                    if self._is_tokenizer_component(val):
-                        tokenizer_wrappers[f"components[{key}]"] = val
-                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
-                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
-                        )
+        original_tokenizers = {}
+
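+        # Swap in lock-guarded tokenizer proxies for the duration of this call;
+        # the originals are restored in the finally block below.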
+        if self._should_wrap_tokenizers():
+            try:
+                for name in dir(local_pipe):
+                    if "tokenizer" in name and not name.startswith("_"):
+                        tok = getattr(local_pipe, name, None)
+                        if tok is not None and self._is_tokenizer_component(tok):
+                            if not isinstance(tok, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[name] = tok
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
+                                setattr(local_pipe, name, wrapped_tokenizer)
+
+                if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                    for key, val in local_pipe.components.items():
+                        if val is None:
+                            continue
+
+                        if self._is_tokenizer_component(val):
+                            if not isinstance(val, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[f"components[{key}]"] = val
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
+                                local_pipe.components[key] = wrapped_tokenizer

-        except Exception as e:
-            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+            except Exception as e:
+                logger.debug(f"Tokenizer wrapping step encountered an error: {e}")

         result = None
         cm = getattr(local_pipe, "model_cpu_offload_context", None)
+
         try:
             if callable(cm):
                 try:
                     with cm():
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                 except TypeError:
-                    # cm might be a context manager instance rather than callable
                     try:
                         with cm:
                             result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                     except Exception as e:
                         logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
             else:
-                # no offload context available — call directly
                 result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

             return result

         finally:
             try:
-                for name, tok in tokenizer_wrappers.items():
+                for name, tok in original_tokenizers.items():
                     if name.startswith("components["):
                         key = name[len("components["):-1]
-                        local_pipe.components[key] = tok
+                        if hasattr(local_pipe, 'components') and isinstance(local_pipe.components, dict):
+                            local_pipe.components[key] = tok
                     else:
                         setattr(local_pipe, name, tok)
             except Exception as e:
-                logger.debug(f"Error restoring wrapped tokenizers: {e}")
+                logger.debug(f"Error restoring original tokenizers: {e}")
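+
+
+# Minimal usage sketch (illustrative only, not part of this module): assumes a
+# standard diffusers text-to-image pipeline; the model id, prompt, and thread
+# count below are placeholders.
+#
+#   import threading
+#   from diffusers import StableDiffusionPipeline
+#
+#   base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+#   rsp = RequestScopedPipeline(base)
+#
+#   def worker():
+#       # each call builds its own shallow-copied pipe with a request-local scheduler
+#       rsp.generate("a watercolor fox", num_inference_steps=20)
+#
+#   threads = [threading.Thread(target=worker) for _ in range(4)]
+#   for t in threads: t.start()
+#   for t in threads: t.join()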