from typing import Optional, Any, Iterable, List
import copy
import threading
import types

import torch
from diffusers.utils import logging

logger = logging.get_logger(__name__)

def safe_tokenize(tokenizer, *args, lock, **kwargs):
    """Call `tokenizer` while holding `lock`.

    Hugging Face fast tokenizers are not safe to call concurrently on the
    same instance (concurrent calls can raise "Already borrowed"), so every
    call is serialized through a shared lock.
    """
    with lock:
        return tokenizer(*args, **kwargs)

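# Usage sketch for the wrapper above (names are illustrative): given a shared
# lock, a thread-safe drop-in replacement for `pipe.tokenizer` can be built as
#
#   lock = threading.Lock()
#   wrapped = lambda *a, **kw: safe_tokenize(pipe.tokenizer, *a, lock=lock, **kw)
#
# `RequestScopedPipeline.generate()` below applies this exact wrapping to every
# tokenizer it detects on the request-local pipeline.
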
class RequestScopedPipeline:
    """Wrap a shared diffusers pipeline so each request runs on a shallow,
    per-request copy: heavy modules (UNet, VAE, text encoder) stay shared,
    while mutable state (scheduler, RNG state, progress bars, small tensors)
    is cloned per call.
    """

    DEFAULT_MUTABLE_ATTRS = [
        "_all_hooks",
        "_offload_device",
        "_progress_bar_config",
        "_progress_bar",
        "_rng_state",
        "_last_seed",
        "latents",
    ]

    def __init__(
        self,
        pipeline: Any,
        mutable_attrs: Optional[Iterable[str]] = None,
        auto_detect_mutables: bool = True,
        tensor_numel_threshold: int = 1_000_000,
        tokenizer_lock: Optional[threading.Lock] = None,
    ):
        self._base = pipeline
        self.unet = getattr(pipeline, "unet", None)
        self.vae = getattr(pipeline, "vae", None)
        self.text_encoder = getattr(pipeline, "text_encoder", None)
        self.components = getattr(pipeline, "components", None)

        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

        self._auto_detect_mutables = bool(auto_detect_mutables)
        self._tensor_numel_threshold = int(tensor_numel_threshold)

        self._auto_detected_attrs: List[str] = []

    def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
        base_sched = getattr(self._base, "scheduler", None)
        if base_sched is None:
            return None

        if hasattr(base_sched, "clone_for_request"):
            try:
                return base_sched.clone_for_request(num_inference_steps=num_inference_steps, device=device, **clone_kwargs)
            except Exception as e:
                logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")

        try:
            return copy.deepcopy(base_sched)
        except Exception as e:
            logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
            return base_sched

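    # Note: `clone_for_request` is not part of the stock diffusers scheduler
    # API; it is an optional hook that the method above probes for. A custom
    # scheduler could provide it along these lines (hypothetical sketch):
    #
    #   class RequestClonableScheduler(DDIMScheduler):
    #       def clone_for_request(self, num_inference_steps, device=None, **kwargs):
    #           sched = copy.deepcopy(self)
    #           sched.set_timesteps(num_inference_steps, device=device)
    #           return sched
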
    def _autodetect_mutables(self, max_attrs: int = 40):
        if not self._auto_detect_mutables:
            return []

        if self._auto_detected_attrs:
            return self._auto_detected_attrs

        candidates: List[str] = []
        for name in dir(self._base):
            if name.startswith("__"):
                continue
            if name in self._mutable_attrs:
                continue
            if name in ("to", "save_pretrained", "from_pretrained"):
                continue
            try:
                val = getattr(self._base, name)
            except Exception:
                continue

            # skip callables and modules
            if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                continue

            # containers -> candidate
            if isinstance(val, (dict, list, set, tuple, bytearray)):
                candidates.append(name)
            else:
                # try Tensor detection
                try:
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            candidates.append(name)
                        else:
                            logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
                except Exception:
                    continue

            if len(candidates) >= max_attrs:
                break

        self._auto_detected_attrs = candidates
        logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
        return self._auto_detected_attrs

    def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
        try:
            descriptor = getattr(type(base_obj), attr_name, None)
            if isinstance(descriptor, property):
                # a property without a setter cannot be assigned on the instance
                return descriptor.fset is None
        except Exception:
            pass
        return False

    def _clone_mutable_attrs(self, base, local):
        attrs_to_clone = list(self._mutable_attrs)
        attrs_to_clone.extend(self._autodetect_mutables())

        EXCLUDE_ATTRS = {"components"}

        for attr in attrs_to_clone:
            if attr in EXCLUDE_ATTRS:
                logger.debug(f"Skipping excluded attr '{attr}'")
                continue
            if not hasattr(base, attr):
                continue
            if self._is_readonly_property(base, attr):
                logger.debug(f"Skipping read-only property '{attr}'")
                continue

            try:
                val = getattr(base, attr)
            except Exception as e:
                logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
                continue

            try:
                if isinstance(val, dict):
                    setattr(local, attr, dict(val))
                elif isinstance(val, (list, tuple, set)):
                    # shallow-copy while preserving the container type
                    setattr(local, attr, copy.copy(val))
                elif isinstance(val, bytearray):
                    setattr(local, attr, bytearray(val))
                elif isinstance(val, torch.Tensor):
                    if val.numel() <= self._tensor_numel_threshold:
                        setattr(local, attr, val.clone())
                    else:
                        # don't clone big tensors, keep the shared reference
                        setattr(local, attr, val)
                else:
                    # atomic values: shallow copy, fall back to the reference
                    try:
                        setattr(local, attr, copy.copy(val))
                    except Exception:
                        setattr(local, attr, val)
            except (AttributeError, TypeError) as e:
                logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
                continue
            except Exception as e:
                logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
                continue

    def _is_tokenizer_component(self, component) -> bool:
        """Heuristically decide whether `component` is a tokenizer."""
        if component is None:
            return False

        tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
        has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)

        class_name = component.__class__.__name__.lower()
        has_tokenizer_in_name = "tokenizer" in class_name

        tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
        has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)

        return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)

    def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
        """Run one request on a request-local shallow copy of the base pipeline."""
        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)

        try:
            local_pipe = copy.copy(self._base)
        except Exception as e:
            logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
            local_pipe = copy.deepcopy(self._base)

        if local_scheduler is not None:
            try:
                setattr(local_pipe, "scheduler", local_scheduler)
            except Exception:
                logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")

        self._clone_mutable_attrs(self._base, local_pipe)

        # wrap tokenizers on the local pipe with the lock wrapper
        tokenizer_wrappers = {}  # name -> original tokenizer
        try:
            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
            for name in dir(local_pipe):
                if "tokenizer" in name and not name.startswith("_"):
                    tok = getattr(local_pipe, name, None)
                    if tok is not None and self._is_tokenizer_component(tok):
                        tokenizer_wrappers[name] = tok
                        setattr(
                            local_pipe,
                            name,
                            lambda *args, tok=tok, **kwargs: safe_tokenize(tok, *args, lock=self._tokenizer_lock, **kwargs),
                        )

            # b) wrap tokenizers in the components dict
            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
                for key, val in local_pipe.components.items():
                    if val is None:
                        continue
                    if self._is_tokenizer_component(val):
                        tokenizer_wrappers[f"components[{key}]"] = val
                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
                        )

        except Exception as e:
            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")

        result = None
        cm = getattr(local_pipe, "model_cpu_offload_context", None)
        try:
            if callable(cm):
                try:
                    with cm():
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                except TypeError:
                    # cm might be a context manager instance rather than a factory
                    try:
                        with cm:
                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                    except Exception as e:
                        logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
            else:
                # no offload context available; call directly
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

            return result

        finally:
            # restore the original tokenizers so the shared base state is untouched
            try:
                for name, tok in tokenizer_wrappers.items():
                    if name.startswith("components["):
                        key = name[len("components["):-1]
                        local_pipe.components[key] = tok
                    else:
                        setattr(local_pipe, name, tok)
            except Exception as e:
                logger.debug(f"Error restoring wrapped tokenizers: {e}")
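
# A minimal usage sketch, not part of the class: assumes a diffusers
# text-to-image pipeline; the model id, prompt, and thread count below are
# illustrative only.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    wrapper = RequestScopedPipeline(base)

    def handle_request(prompt: str):
        # each thread gets its own shallow pipeline copy and lock-guarded tokenizers
        return wrapper.generate(prompt, num_inference_steps=30)

    threads = [threading.Thread(target=handle_request, args=(f"a photo of a cat, variant {i}",)) for i in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()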