Skip to content

Commit 581847f

Browse files
Apply style fixes
1 parent 8072fba commit 581847f

File tree

2 files changed

+79
-56
lines changed

2 files changed

+79
-56
lines changed

examples/server-async/utils/requestscopedpipeline.py

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1-
from typing import Optional, Any, Iterable, List
21
import copy
32
import threading
3+
from typing import Any, Iterable, List, Optional
4+
45
import torch
6+
57
from diffusers.utils import logging
8+
69
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
7-
from .wrappers import ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper, ThreadSafeImageProcessorWrapper
10+
from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
11+
812

913
logger = logging.get_logger(__name__)
1014

15+
1116
class RequestScopedPipeline:
1217
DEFAULT_MUTABLE_ATTRS = [
1318
"_all_hooks",
14-
"_offload_device",
19+
"_offload_device",
1520
"_progress_bar_config",
1621
"_progress_bar",
1722
"_rng_state",
@@ -29,42 +34,39 @@ def __init__(
2934
wrap_scheduler: bool = True,
3035
):
3136
self._base = pipeline
32-
33-
37+
3438
self.unet = getattr(pipeline, "unet", None)
35-
self.vae = getattr(pipeline, "vae", None)
39+
self.vae = getattr(pipeline, "vae", None)
3640
self.text_encoder = getattr(pipeline, "text_encoder", None)
3741
self.components = getattr(pipeline, "components", None)
38-
42+
3943
self.transformer = getattr(pipeline, "transformer", None)
40-
41-
if wrap_scheduler and hasattr(pipeline, 'scheduler') and pipeline.scheduler is not None:
44+
45+
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
4246
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
4347
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
4448

4549
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
46-
47-
50+
4851
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
4952

5053
self._vae_lock = threading.Lock()
5154
self._image_lock = threading.Lock()
52-
55+
5356
self._auto_detect_mutables = bool(auto_detect_mutables)
5457
self._tensor_numel_threshold = int(tensor_numel_threshold)
5558
self._auto_detected_attrs: List[str] = []
5659

5760
def _detect_kernel_pipeline(self, pipeline) -> bool:
5861
kernel_indicators = [
59-
'text_encoding_cache',
60-
'memory_manager',
61-
'enable_optimizations',
62-
'_create_request_context',
63-
'get_optimization_stats'
62+
"text_encoding_cache",
63+
"memory_manager",
64+
"enable_optimizations",
65+
"_create_request_context",
66+
"get_optimization_stats",
6467
]
65-
66-
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
6768

69+
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
6870

6971
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
7072
base_sched = getattr(self._base, "scheduler", None)
@@ -78,14 +80,12 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
7880

7981
try:
8082
return wrapped_scheduler.clone_for_request(
81-
num_inference_steps=num_inference_steps,
82-
device=device,
83-
**clone_kwargs
83+
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
8484
)
8585
except Exception as e:
8686
logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
8787
try:
88-
if hasattr(wrapped_scheduler, 'scheduler'):
88+
if hasattr(wrapped_scheduler, "scheduler"):
8989
try:
9090
copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
9191
return BaseAsyncScheduler(copied_scheduler)
@@ -95,8 +95,10 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
9595
copied_scheduler = copy.copy(wrapped_scheduler)
9696
return BaseAsyncScheduler(copied_scheduler)
9797
except Exception as e2:
98-
logger.warning(f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*).")
99-
return wrapped_scheduler
98+
logger.warning(
99+
f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
100+
)
101+
return wrapped_scheduler
100102

101103
def _autodetect_mutables(self, max_attrs: int = 40):
102104
if not self._auto_detect_mutables:
@@ -107,16 +109,15 @@ def _autodetect_mutables(self, max_attrs: int = 40):
107109

108110
candidates: List[str] = []
109111
seen = set()
110-
111-
112+
112113
for name in dir(self._base):
113114
if name.startswith("__"):
114115
continue
115116
if name in self._mutable_attrs:
116117
continue
117118
if name in ("to", "save_pretrained", "from_pretrained"):
118119
continue
119-
120+
120121
try:
121122
val = getattr(self._base, name)
122123
except Exception:
@@ -165,7 +166,9 @@ def _clone_mutable_attrs(self, base, local):
165166
attrs_to_clone = list(self._mutable_attrs)
166167
attrs_to_clone.extend(self._autodetect_mutables())
167168

168-
EXCLUDE_ATTRS = {"components",}
169+
EXCLUDE_ATTRS = {
170+
"components",
171+
}
169172

170173
for attr in attrs_to_clone:
171174
if attr in EXCLUDE_ATTRS:
@@ -213,16 +216,16 @@ def _clone_mutable_attrs(self, base, local):
213216
def _is_tokenizer_component(self, component) -> bool:
214217
if component is None:
215218
return False
216-
217-
tokenizer_methods = ['encode', 'decode', 'tokenize', '__call__']
219+
220+
tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
218221
has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
219-
222+
220223
class_name = component.__class__.__name__.lower()
221-
has_tokenizer_in_name = 'tokenizer' in class_name
222-
223-
tokenizer_attrs = ['vocab_size', 'pad_token', 'eos_token', 'bos_token']
224+
has_tokenizer_in_name = "tokenizer" in class_name
225+
226+
tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
224227
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
225-
228+
226229
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
227230

228231
def _should_wrap_tokenizers(self) -> bool:
@@ -238,11 +241,21 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
238241
local_pipe = copy.deepcopy(self._base)
239242

240243
try:
241-
if hasattr(local_pipe, "vae") and local_pipe.vae is not None and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper):
244+
if (
245+
hasattr(local_pipe, "vae")
246+
and local_pipe.vae is not None
247+
and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
248+
):
242249
local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
243250

244-
if hasattr(local_pipe, "image_processor") and local_pipe.image_processor is not None and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper):
245-
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(local_pipe.image_processor, self._image_lock)
251+
if (
252+
hasattr(local_pipe, "image_processor")
253+
and local_pipe.image_processor is not None
254+
and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
255+
):
256+
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
257+
local_pipe.image_processor, self._image_lock
258+
)
246259
except Exception as e:
247260
logger.debug(f"Could not wrap vae/image_processor: {e}")
248261

@@ -253,7 +266,7 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
253266
num_inference_steps=num_inference_steps,
254267
device=device,
255268
return_scheduler=True,
256-
**{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
269+
**{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
257270
)
258271

259272
final_scheduler = BaseAsyncScheduler(configured_scheduler)
@@ -262,10 +275,9 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
262275
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
263276

264277
self._clone_mutable_attrs(self._base, local_pipe)
265-
266278

267279
original_tokenizers = {}
268-
280+
269281
if self._should_wrap_tokenizers():
270282
try:
271283
for name in dir(local_pipe):
@@ -281,7 +293,7 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
281293
for key, val in local_pipe.components.items():
282294
if val is None:
283295
continue
284-
296+
285297
if self._is_tokenizer_component(val):
286298
if not isinstance(val, ThreadSafeTokenizerWrapper):
287299
original_tokenizers[f"components[{key}]"] = val
@@ -293,9 +305,8 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
293305

294306
result = None
295307
cm = getattr(local_pipe, "model_cpu_offload_context", None)
296-
308+
297309
try:
298-
299310
if callable(cm):
300311
try:
301312
with cm():
@@ -316,8 +327,8 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
316327
try:
317328
for name, tok in original_tokenizers.items():
318329
if name.startswith("components["):
319-
key = name[len("components["):-1]
320-
if hasattr(local_pipe, 'components') and isinstance(local_pipe.components, dict):
330+
key = name[len("components[") : -1]
331+
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
321332
local_pipe.components[key] = tok
322333
else:
323334
setattr(local_pipe, name, tok)

examples/server-async/utils/wrappers.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,38 @@ def __init__(self, tokenizer, lock):
44
self._lock = lock
55

66
self._thread_safe_methods = {
7-
'__call__', 'encode', 'decode', 'tokenize',
8-
'encode_plus', 'batch_encode_plus', 'batch_decode'
7+
"__call__",
8+
"encode",
9+
"decode",
10+
"tokenize",
11+
"encode_plus",
12+
"batch_encode_plus",
13+
"batch_decode",
914
}
10-
15+
1116
def __getattr__(self, name):
1217
attr = getattr(self._tokenizer, name)
13-
18+
1419
if name in self._thread_safe_methods and callable(attr):
20+
1521
def wrapped_method(*args, **kwargs):
1622
with self._lock:
1723
return attr(*args, **kwargs)
24+
1825
return wrapped_method
19-
26+
2027
return attr
2128

2229
def __call__(self, *args, **kwargs):
2330
with self._lock:
2431
return self._tokenizer(*args, **kwargs)
25-
32+
2633
def __setattr__(self, name, value):
27-
if name.startswith('_'):
34+
if name.startswith("_"):
2835
super().__setattr__(name, value)
2936
else:
3037
setattr(self._tokenizer, name, value)
31-
38+
3239
def __dir__(self):
3340
return dir(self._tokenizer)
3441

@@ -41,9 +48,11 @@ def __init__(self, vae, lock):
4148
def __getattr__(self, name):
4249
attr = getattr(self._vae, name)
4350
if name in {"decode", "encode", "forward"} and callable(attr):
51+
4452
def wrapped(*args, **kwargs):
4553
with self._lock:
4654
return attr(*args, **kwargs)
55+
4756
return wrapped
4857
return attr
4958

@@ -53,6 +62,7 @@ def __setattr__(self, name, value):
5362
else:
5463
setattr(self._vae, name, value)
5564

65+
5666
class ThreadSafeImageProcessorWrapper:
5767
def __init__(self, proc, lock):
5868
self._proc = proc
@@ -61,14 +71,16 @@ def __init__(self, proc, lock):
6171
def __getattr__(self, name):
6272
attr = getattr(self._proc, name)
6373
if name in {"postprocess", "preprocess"} and callable(attr):
74+
6475
def wrapped(*args, **kwargs):
6576
with self._lock:
6677
return attr(*args, **kwargs)
78+
6779
return wrapped
6880
return attr
6981

7082
def __setattr__(self, name, value):
7183
if name.startswith("_"):
7284
super().__setattr__(name, value)
7385
else:
74-
setattr(self._proc, name, value)
86+
setattr(self._proc, name, value)

0 commit comments

Comments
 (0)