Skip to content

Commit 489da5d

Browse files
Add wrappers.py
1 parent f2e9f02 commit 489da5d

File tree

2 files changed

+75
-76
lines changed

2 files changed

+75
-76
lines changed

examples/server-async/utils/requestscopedpipeline.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -4,85 +4,10 @@
44
import torch
55
from diffusers.utils import logging
66
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
7+
from .wrappers import ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper, ThreadSafeImageProcessorWrapper
78

89
logger = logging.get_logger(__name__)
910

10-
class ThreadSafeTokenizerWrapper:
11-
def __init__(self, tokenizer, lock):
12-
self._tokenizer = tokenizer
13-
self._lock = lock
14-
15-
self._thread_safe_methods = {
16-
'__call__', 'encode', 'decode', 'tokenize',
17-
'encode_plus', 'batch_encode_plus', 'batch_decode'
18-
}
19-
20-
def __getattr__(self, name):
21-
attr = getattr(self._tokenizer, name)
22-
23-
if name in self._thread_safe_methods and callable(attr):
24-
def wrapped_method(*args, **kwargs):
25-
with self._lock:
26-
return attr(*args, **kwargs)
27-
return wrapped_method
28-
29-
return attr
30-
31-
def __call__(self, *args, **kwargs):
32-
with self._lock:
33-
return self._tokenizer(*args, **kwargs)
34-
35-
def __setattr__(self, name, value):
36-
if name.startswith('_'):
37-
super().__setattr__(name, value)
38-
else:
39-
setattr(self._tokenizer, name, value)
40-
41-
def __dir__(self):
42-
return dir(self._tokenizer)
43-
44-
45-
class ThreadSafeVAEWrapper:
46-
def __init__(self, vae, lock):
47-
self._vae = vae
48-
self._lock = lock
49-
50-
def __getattr__(self, name):
51-
attr = getattr(self._vae, name)
52-
# métodos que queremos proteger
53-
if name in {"decode", "encode", "forward"} and callable(attr):
54-
def wrapped(*args, **kwargs):
55-
with self._lock:
56-
return attr(*args, **kwargs)
57-
return wrapped
58-
return attr
59-
60-
def __setattr__(self, name, value):
61-
if name.startswith("_"):
62-
super().__setattr__(name, value)
63-
else:
64-
setattr(self._vae, name, value)
65-
66-
class ThreadSafeImageProcessorWrapper:
67-
def __init__(self, proc, lock):
68-
self._proc = proc
69-
self._lock = lock
70-
71-
def __getattr__(self, name):
72-
attr = getattr(self._proc, name)
73-
if name in {"postprocess", "preprocess"} and callable(attr):
74-
def wrapped(*args, **kwargs):
75-
with self._lock:
76-
return attr(*args, **kwargs)
77-
return wrapped
78-
return attr
79-
80-
def __setattr__(self, name, value):
81-
if name.startswith("_"):
82-
super().__setattr__(name, value)
83-
else:
84-
setattr(self._proc, name, value)
85-
8611
class RequestScopedPipeline:
8712
DEFAULT_MUTABLE_ATTRS = [
8813
"_all_hooks",
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
class ThreadSafeTokenizerWrapper:
2+
def __init__(self, tokenizer, lock):
3+
self._tokenizer = tokenizer
4+
self._lock = lock
5+
6+
self._thread_safe_methods = {
7+
'__call__', 'encode', 'decode', 'tokenize',
8+
'encode_plus', 'batch_encode_plus', 'batch_decode'
9+
}
10+
11+
def __getattr__(self, name):
12+
attr = getattr(self._tokenizer, name)
13+
14+
if name in self._thread_safe_methods and callable(attr):
15+
def wrapped_method(*args, **kwargs):
16+
with self._lock:
17+
return attr(*args, **kwargs)
18+
return wrapped_method
19+
20+
return attr
21+
22+
def __call__(self, *args, **kwargs):
23+
with self._lock:
24+
return self._tokenizer(*args, **kwargs)
25+
26+
def __setattr__(self, name, value):
27+
if name.startswith('_'):
28+
super().__setattr__(name, value)
29+
else:
30+
setattr(self._tokenizer, name, value)
31+
32+
def __dir__(self):
33+
return dir(self._tokenizer)
34+
35+
36+
class ThreadSafeVAEWrapper:
37+
def __init__(self, vae, lock):
38+
self._vae = vae
39+
self._lock = lock
40+
41+
def __getattr__(self, name):
42+
attr = getattr(self._vae, name)
43+
if name in {"decode", "encode", "forward"} and callable(attr):
44+
def wrapped(*args, **kwargs):
45+
with self._lock:
46+
return attr(*args, **kwargs)
47+
return wrapped
48+
return attr
49+
50+
def __setattr__(self, name, value):
51+
if name.startswith("_"):
52+
super().__setattr__(name, value)
53+
else:
54+
setattr(self._vae, name, value)
55+
56+
class ThreadSafeImageProcessorWrapper:
57+
def __init__(self, proc, lock):
58+
self._proc = proc
59+
self._lock = lock
60+
61+
def __getattr__(self, name):
62+
attr = getattr(self._proc, name)
63+
if name in {"postprocess", "preprocess"} and callable(attr):
64+
def wrapped(*args, **kwargs):
65+
with self._lock:
66+
return attr(*args, **kwargs)
67+
return wrapped
68+
return attr
69+
70+
def __setattr__(self, name, value):
71+
if name.startswith("_"):
72+
super().__setattr__(name, value)
73+
else:
74+
setattr(self._proc, name, value)

0 commit comments

Comments
 (0)