Skip to content

Commit f2e9f02

Browse files
Add thread-safe wrappers for components in pipeline
Refactor requestscopedpipeline.py to add thread-safe wrappers for tokenizer, VAE, and image processor. Introduce locking mechanisms to ensure thread safety during concurrent access.
1 parent c91e6f4 commit f2e9f02

File tree

1 file changed

+173
-69
lines changed

1 file changed

+173
-69
lines changed

examples/server-async/utils/requestscopedpipeline.py

Lines changed: 173 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,92 @@
1+
from typing import Optional, Any, Iterable, List
12
import copy
23
import threading
3-
from typing import Any, Iterable, List, Optional
4-
54
import torch
6-
75
from diffusers.utils import logging
8-
96
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
107

11-
128
logger = logging.get_logger(__name__)
139

10+
class ThreadSafeTokenizerWrapper:
    """Proxy that serializes calls into a shared tokenizer.

    Concurrent requests share one underlying tokenizer object; the methods
    listed in ``_thread_safe_methods`` (and direct ``__call__``) are executed
    under ``lock`` so two threads never mutate tokenizer state at once.
    All other attribute reads are delegated unchanged, and writes to
    non-underscore attributes are forwarded to the wrapped tokenizer.
    """

    # Internal slots owned by the wrapper itself.  __getattr__ must never
    # delegate these: on a partially constructed instance (e.g. while
    # copy/pickle rebuilds the object without running __init__, which the
    # pipeline's copy.copy/deepcopy fallback can trigger) looking up
    # self._tokenizer would re-enter __getattr__ and recurse forever.
    _INTERNAL = ("_tokenizer", "_lock", "_thread_safe_methods")

    def __init__(self, tokenizer, lock):
        self._tokenizer = tokenizer
        self._lock = lock

        self._thread_safe_methods = {
            '__call__', 'encode', 'decode', 'tokenize',
            'encode_plus', 'batch_encode_plus', 'batch_decode'
        }

    def __getattr__(self, name):
        # Only reached when normal lookup fails; bail out for our own
        # internals instead of recursing (see _INTERNAL note above).
        if name in ThreadSafeTokenizerWrapper._INTERNAL:
            raise AttributeError(name)

        attr = getattr(self._tokenizer, name)

        if name in self._thread_safe_methods and callable(attr):
            def wrapped_method(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)
            return wrapped_method

        return attr

    def __call__(self, *args, **kwargs):
        # Direct invocation of the tokenizer, serialized under the lock
        # (dunder calls bypass __getattr__, so this must be explicit).
        with self._lock:
            return self._tokenizer(*args, **kwargs)

    def __setattr__(self, name, value):
        # Underscore names belong to the wrapper; everything else is
        # forwarded so callers can still configure the real tokenizer.
        if name.startswith('_'):
            super().__setattr__(name, value)
        else:
            setattr(self._tokenizer, name, value)

    def __dir__(self):
        # Make introspection show the wrapped tokenizer's surface.
        return dir(self._tokenizer)
43+
44+
45+
class ThreadSafeVAEWrapper:
    """Proxy that serializes ``decode``/``encode``/``forward`` calls into a
    shared VAE under ``lock``; all other attribute access is delegated to
    the wrapped object, and writes to non-underscore attributes are
    forwarded as well.
    """

    # Wrapper-owned slots.  __getattr__ must never delegate these: on a
    # partially constructed instance (copy/pickle reconstruction bypasses
    # __init__) looking up self._vae would re-enter __getattr__ and
    # recurse forever.
    _INTERNAL = ("_vae", "_lock")

    def __init__(self, vae, lock):
        self._vae = vae
        self._lock = lock

    def __getattr__(self, name):
        if name in ThreadSafeVAEWrapper._INTERNAL:
            raise AttributeError(name)

        attr = getattr(self._vae, name)
        # methods we want to protect with the lock
        if name in {"decode", "encode", "forward"} and callable(attr):
            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)
            return wrapped
        return attr

    def __setattr__(self, name, value):
        # Underscore names belong to the wrapper; everything else is
        # forwarded to the real VAE.
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._vae, name, value)
65+
66+
class ThreadSafeImageProcessorWrapper:
    """Proxy that serializes ``preprocess``/``postprocess`` calls into a
    shared image processor under ``lock``; all other attribute access is
    delegated, and writes to non-underscore attributes are forwarded.
    """

    # Wrapper-owned slots.  __getattr__ must never delegate these: on a
    # partially constructed instance (copy/pickle reconstruction bypasses
    # __init__) looking up self._proc would re-enter __getattr__ and
    # recurse forever.
    _INTERNAL = ("_proc", "_lock")

    def __init__(self, proc, lock):
        self._proc = proc
        self._lock = lock

    def __getattr__(self, name):
        if name in ThreadSafeImageProcessorWrapper._INTERNAL:
            raise AttributeError(name)

        attr = getattr(self._proc, name)
        if name in {"postprocess", "preprocess"} and callable(attr):
            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)
            return wrapped
        return attr

    def __setattr__(self, name, value):
        # Underscore names belong to the wrapper; everything else is
        # forwarded to the real processor.
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._proc, name, value)
1985

2086
class RequestScopedPipeline:
2187
DEFAULT_MUTABLE_ATTRS = [
2288
"_all_hooks",
23-
"_offload_device",
89+
"_offload_device",
2490
"_progress_bar_config",
2591
"_progress_bar",
2692
"_rng_state",
@@ -38,23 +104,43 @@ def __init__(
38104
wrap_scheduler: bool = True,
39105
):
40106
self._base = pipeline
107+
108+
41109
self.unet = getattr(pipeline, "unet", None)
42-
self.vae = getattr(pipeline, "vae", None)
110+
self.vae = getattr(pipeline, "vae", None)
43111
self.text_encoder = getattr(pipeline, "text_encoder", None)
44112
self.components = getattr(pipeline, "components", None)
45-
46-
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
113+
114+
self.transformer = getattr(pipeline, "transformer", None)
115+
116+
if wrap_scheduler and hasattr(pipeline, 'scheduler') and pipeline.scheduler is not None:
47117
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
48118
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
49119

50120
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
121+
122+
51123
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
52124

125+
self._vae_lock = threading.Lock()
126+
self._image_lock = threading.Lock()
127+
53128
self._auto_detect_mutables = bool(auto_detect_mutables)
54129
self._tensor_numel_threshold = int(tensor_numel_threshold)
55-
56130
self._auto_detected_attrs: List[str] = []
57131

132+
def _detect_kernel_pipeline(self, pipeline) -> bool:
133+
kernel_indicators = [
134+
'text_encoding_cache',
135+
'memory_manager',
136+
'enable_optimizations',
137+
'_create_request_context',
138+
'get_optimization_stats'
139+
]
140+
141+
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
142+
143+
58144
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
59145
base_sched = getattr(self._base, "scheduler", None)
60146
if base_sched is None:
@@ -67,15 +153,25 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
67153

68154
try:
69155
return wrapped_scheduler.clone_for_request(
70-
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
156+
num_inference_steps=num_inference_steps,
157+
device=device,
158+
**clone_kwargs
71159
)
72160
except Exception as e:
73-
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
161+
logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
74162
try:
75-
return copy.deepcopy(wrapped_scheduler)
76-
except Exception as e:
77-
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
78-
return wrapped_scheduler
163+
if hasattr(wrapped_scheduler, 'scheduler'):
164+
try:
165+
copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
166+
return BaseAsyncScheduler(copied_scheduler)
167+
except Exception:
168+
return wrapped_scheduler
169+
else:
170+
copied_scheduler = copy.copy(wrapped_scheduler)
171+
return BaseAsyncScheduler(copied_scheduler)
172+
except Exception as e2:
173+
logger.warning(f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*).")
174+
return wrapped_scheduler
79175

80176
def _autodetect_mutables(self, max_attrs: int = 40):
81177
if not self._auto_detect_mutables:
@@ -86,25 +182,26 @@ def _autodetect_mutables(self, max_attrs: int = 40):
86182

87183
candidates: List[str] = []
88184
seen = set()
185+
186+
89187
for name in dir(self._base):
90188
if name.startswith("__"):
91189
continue
92190
if name in self._mutable_attrs:
93191
continue
94192
if name in ("to", "save_pretrained", "from_pretrained"):
95193
continue
194+
96195
try:
97196
val = getattr(self._base, name)
98197
except Exception:
99198
continue
100199

101200
import types
102201

103-
# skip callables and modules
104202
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
105203
continue
106204

107-
# containers -> candidate
108205
if isinstance(val, (dict, list, set, tuple, bytearray)):
109206
candidates.append(name)
110207
seen.add(name)
@@ -143,9 +240,7 @@ def _clone_mutable_attrs(self, base, local):
143240
attrs_to_clone = list(self._mutable_attrs)
144241
attrs_to_clone.extend(self._autodetect_mutables())
145242

146-
EXCLUDE_ATTRS = {
147-
"components",
148-
}
243+
EXCLUDE_ATTRS = {"components",}
149244

150245
for attr in attrs_to_clone:
151246
if attr in EXCLUDE_ATTRS:
@@ -193,18 +288,21 @@ def _clone_mutable_attrs(self, base, local):
193288
def _is_tokenizer_component(self, component) -> bool:
194289
if component is None:
195290
return False
196-
197-
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
291+
292+
tokenizer_methods = ['encode', 'decode', 'tokenize', '__call__']
198293
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
199-
294+
200295
class_name = component.__class__.__name__.lower()
201-
has_tokenizer_in_name = "tokenizer" in class_name
202-
203-
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
296+
has_tokenizer_in_name = 'tokenizer' in class_name
297+
298+
tokenizer_attrs = ['vocab_size', 'pad_token', 'eos_token', 'bos_token']
204299
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
205-
300+
206301
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
207302

303+
def _should_wrap_tokenizers(self) -> bool:
304+
return True
305+
208306
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
209307
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
210308

@@ -214,14 +312,23 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
214312
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
215313
local_pipe = copy.deepcopy(self._base)
216314

315+
try:
316+
if hasattr(local_pipe, "vae") and local_pipe.vae is not None and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper):
317+
local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
318+
319+
if hasattr(local_pipe, "image_processor") and local_pipe.image_processor is not None and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper):
320+
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(local_pipe.image_processor, self._image_lock)
321+
except Exception as e:
322+
logger.debug(f"Could not wrap vae/image_processor: {e}")
323+
217324
if local_scheduler is not None:
218325
try:
219326
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
220327
local_scheduler.scheduler,
221328
num_inference_steps=num_inference_steps,
222329
device=device,
223330
return_scheduler=True,
224-
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
331+
**{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
225332
)
226333

227334
final_scheduler = BaseAsyncScheduler(configured_scheduler)
@@ -230,67 +337,64 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
230337
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
231338

232339
self._clone_mutable_attrs(self._base, local_pipe)
340+
233341

234-
# 4) wrap tokenizers on the local pipe with the lock wrapper
235-
tokenizer_wrappers = {} # name -> original_tokenizer
236-
try:
237-
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
238-
for name in dir(local_pipe):
239-
if "tokenizer" in name and not name.startswith("_"):
240-
tok = getattr(local_pipe, name, None)
241-
if tok is not None and self._is_tokenizer_component(tok):
242-
tokenizer_wrappers[name] = tok
243-
setattr(
244-
local_pipe,
245-
name,
246-
lambda *args, tok=tok, **kwargs: safe_tokenize(
247-
tok, *args, lock=self._tokenizer_lock, **kwargs
248-
),
249-
)
250-
251-
# b) wrap tokenizers in components dict
252-
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
253-
for key, val in local_pipe.components.items():
254-
if val is None:
255-
continue
256-
257-
if self._is_tokenizer_component(val):
258-
tokenizer_wrappers[f"components[{key}]"] = val
259-
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
260-
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
261-
)
342+
original_tokenizers = {}
343+
344+
if self._should_wrap_tokenizers():
345+
try:
346+
for name in dir(local_pipe):
347+
if "tokenizer" in name and not name.startswith("_"):
348+
tok = getattr(local_pipe, name, None)
349+
if tok is not None and self._is_tokenizer_component(tok):
350+
if not isinstance(tok, ThreadSafeTokenizerWrapper):
351+
original_tokenizers[name] = tok
352+
wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
353+
setattr(local_pipe, name, wrapped_tokenizer)
354+
355+
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
356+
for key, val in local_pipe.components.items():
357+
if val is None:
358+
continue
359+
360+
if self._is_tokenizer_component(val):
361+
if not isinstance(val, ThreadSafeTokenizerWrapper):
362+
original_tokenizers[f"components[{key}]"] = val
363+
wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
364+
local_pipe.components[key] = wrapped_tokenizer
262365

263-
except Exception as e:
264-
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
366+
except Exception as e:
367+
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
265368

266369
result = None
267370
cm = getattr(local_pipe, "model_cpu_offload_context", None)
371+
268372
try:
373+
269374
if callable(cm):
270375
try:
271376
with cm():
272377
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
273378
except TypeError:
274-
# cm might be a context manager instance rather than callable
275379
try:
276380
with cm:
277381
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
278382
except Exception as e:
279383
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
280384
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
281385
else:
282-
# no offload context available — call directly
283386
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
284387

285388
return result
286389

287390
finally:
288391
try:
289-
for name, tok in tokenizer_wrappers.items():
392+
for name, tok in original_tokenizers.items():
290393
if name.startswith("components["):
291-
key = name[len("components[") : -1]
292-
local_pipe.components[key] = tok
394+
key = name[len("components["):-1]
395+
if hasattr(local_pipe, 'components') and isinstance(local_pipe.components, dict):
396+
local_pipe.components[key] = tok
293397
else:
294398
setattr(local_pipe, name, tok)
295399
except Exception as e:
296-
logger.debug(f"Error restoring wrapped tokenizers: {e}")
400+
logger.debug(f"Error restoring original tokenizers: {e}")

0 commit comments

Comments
 (0)