
Commit 06bb136

The changes to the diffusers core have been reverted and all logic is being moved to examples/server-async
1 parent a9666b1

38 files changed: +498 -788 lines changed

examples/server-async/DiffusersServer/serverasync.py

Lines changed: 2 additions & 3 deletions
@@ -3,11 +3,10 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel
-from .Pipelines import TextToImagePipelineSD3, TextToImagePipelineFlux, TextToImagePipelineSD
+from .Pipelines import TextToImagePipelineSD3, TextToImagePipelineFlux, TextToImagePipelineSD, logger
 import logging
-from diffusers.pipelines.pipeline_utils import RequestScopedPipeline
+from ..utils import RequestScopedPipeline
 from diffusers import *
-from .superpipeline import *
 import random
 import uuid
 import tempfile
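
The hunk above replaces the diffusers-core import with the local `..utils` package and drops the `superpipeline` import. For orientation, here is a minimal sketch of how a FastAPI handler might drive the relocated RequestScopedPipeline; the endpoint path, checkpoint id, and variable names are illustrative assumptions, not code from this commit:

import threading

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from diffusers import DiffusionPipeline

# Hypothetical flat-import stand-in for the package-relative
# `from ..utils import RequestScopedPipeline` used in this commit.
from utils import RequestScopedPipeline

# Load the heavyweight pipeline once at startup; it is shared by all requests.
base = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
request_pipe = RequestScopedPipeline(base, tokenizer_lock=threading.Lock())

app = FastAPI()

@app.post("/api/diffusers/inference")
async def generate_image(prompt: str):
    # run_in_threadpool keeps the event loop responsive while generate()
    # hands this request its own shallow copy of the pipeline.
    result = await run_in_threadpool(
        request_pipe.generate, prompt=prompt, num_inference_steps=30
    )
    return {"images": len(result.images)}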
examples/server-async/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .requestscopedpipeline import RequestScopedPipeline
examples/server-async/utils/requestscopedpipeline.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
+from typing import Optional, Any, Iterable, List
+import copy
+import threading
+import torch
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+def safe_tokenize(tokenizer, *args, lock, **kwargs):
+    with lock:
+        return tokenizer(*args, **kwargs)
+
+class RequestScopedPipeline:
+    DEFAULT_MUTABLE_ATTRS = [
+        "_all_hooks",
+        "_offload_device",
+        "_progress_bar_config",
+        "_progress_bar",
+        "_rng_state",
+        "_last_seed",
+        "latents",
+    ]
+
+    def __init__(
+        self,
+        pipeline: Any,
+        mutable_attrs: Optional[Iterable[str]] = None,
+        auto_detect_mutables: bool = True,
+        tensor_numel_threshold: int = 1_000_000,
+        tokenizer_lock: Optional[threading.Lock] = None
+    ):
+        self._base = pipeline
+        self.unet = getattr(pipeline, "unet", None)
+        self.vae = getattr(pipeline, "vae", None)
+        self.text_encoder = getattr(pipeline, "text_encoder", None)
+        self.components = getattr(pipeline, "components", None)
+
+        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
+
+        self._auto_detect_mutables = bool(auto_detect_mutables)
+        self._tensor_numel_threshold = int(tensor_numel_threshold)
+
+        self._auto_detected_attrs: List[str] = []
+
+    def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
+        base_sched = getattr(self._base, "scheduler", None)
+        if base_sched is None:
+            return None
+
+        if hasattr(base_sched, "clone_for_request"):
+            try:
+                return base_sched.clone_for_request(num_inference_steps=num_inference_steps, device=device, **clone_kwargs)
+            except Exception as e:
+                logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+
+        try:
+            return copy.deepcopy(base_sched)
+        except Exception as e:
+            logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+            return base_sched
+
+    def _autodetect_mutables(self, max_attrs: int = 40):
+        if not self._auto_detect_mutables:
+            return []
+
+        if self._auto_detected_attrs:
+            return self._auto_detected_attrs
+
+        candidates: List[str] = []
+        seen = set()
+        for name in dir(self._base):
+            if name.startswith("__"):
+                continue
+            if name in self._mutable_attrs:
+                continue
+            if name in ("to", "save_pretrained", "from_pretrained"):
+                continue
+            try:
+                val = getattr(self._base, name)
+            except Exception:
+                continue
+
+            import types
+
+            # skip callables and modules
+            if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
+                continue
+
+            # containers -> candidate
+            if isinstance(val, (dict, list, set, tuple, bytearray)):
+                candidates.append(name)
+                seen.add(name)
+            else:
+                # try Tensor detection
+                try:
+                    if isinstance(val, torch.Tensor):
+                        if val.numel() <= self._tensor_numel_threshold:
+                            candidates.append(name)
+                            seen.add(name)
+                        else:
+                            logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
+                except Exception:
+                    continue
+
+            if len(candidates) >= max_attrs:
+                break
+
+        self._auto_detected_attrs = candidates
+        logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
+        return self._auto_detected_attrs
+
+    def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
+        try:
+            cls = type(base_obj)
+            descriptor = getattr(cls, attr_name, None)
+            if isinstance(descriptor, property):
+                return descriptor.fset is None
+            if hasattr(descriptor, "__set__") is False and descriptor is not None:
+                return False
+        except Exception:
+            pass
+        return False
+
+    def _clone_mutable_attrs(self, base, local):
+        attrs_to_clone = list(self._mutable_attrs)
+        attrs_to_clone.extend(self._autodetect_mutables())
+
+        EXCLUDE_ATTRS = {"components",}
+
+        for attr in attrs_to_clone:
+            if attr in EXCLUDE_ATTRS:
+                logger.debug(f"Skipping excluded attr '{attr}'")
+                continue
+            if not hasattr(base, attr):
+                continue
+            if self._is_readonly_property(base, attr):
+                logger.debug(f"Skipping read-only property '{attr}'")
+                continue
+
+            try:
+                val = getattr(base, attr)
+            except Exception as e:
+                logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
+                continue
+
+            try:
+                if isinstance(val, dict):
+                    setattr(local, attr, dict(val))
+                elif isinstance(val, (list, tuple, set)):
+                    setattr(local, attr, list(val))
+                elif isinstance(val, bytearray):
+                    setattr(local, attr, bytearray(val))
+                else:
+                    # small tensors or atomic values
+                    if isinstance(val, torch.Tensor):
+                        if val.numel() <= self._tensor_numel_threshold:
+                            setattr(local, attr, val.clone())
+                        else:
+                            # don't clone big tensors, keep reference
+                            setattr(local, attr, val)
+                    else:
+                        try:
+                            setattr(local, attr, copy.copy(val))
+                        except Exception:
+                            setattr(local, attr, val)
+            except (AttributeError, TypeError) as e:
+                logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
+                continue
+            except Exception as e:
+                logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
+                continue
+
+    def _is_tokenizer_component(self, component) -> bool:
+        if component is None:
+            return False
+
+        tokenizer_methods = ['encode', 'decode', 'tokenize', '__call__']
+        has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
+
+        class_name = component.__class__.__name__.lower()
+        has_tokenizer_in_name = 'tokenizer' in class_name
+
+        tokenizer_attrs = ['vocab_size', 'pad_token', 'eos_token', 'bos_token']
+        has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
+
+        return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
+
+    def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
+        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
+
+        try:
+            local_pipe = copy.copy(self._base)
+        except Exception as e:
+            logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
+            local_pipe = copy.deepcopy(self._base)
+
+        if local_scheduler is not None:
+            try:
+                setattr(local_pipe, "scheduler", local_scheduler)
+            except Exception:
+                logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
+
+        self._clone_mutable_attrs(self._base, local_pipe)
+
+        # 4) wrap tokenizers on the local pipe with the lock wrapper
+        tokenizer_wrappers = {}  # name -> original_tokenizer
+        try:
+            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
+            for name in dir(local_pipe):
+                if "tokenizer" in name and not name.startswith("_"):
+                    tok = getattr(local_pipe, name, None)
+                    if tok is not None and self._is_tokenizer_component(tok):
+                        tokenizer_wrappers[name] = tok
+                        setattr(
+                            local_pipe,
+                            name,
+                            lambda *args, tok=tok, **kwargs: safe_tokenize(tok, *args, lock=self._tokenizer_lock, **kwargs)
+                        )
+
+            # b) wrap tokenizers in components dict
+            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                for key, val in local_pipe.components.items():
+                    if val is None:
+                        continue
+
+                    if self._is_tokenizer_component(val):
+                        tokenizer_wrappers[f"components[{key}]"] = val
+                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
+                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
+                        )
+
+        except Exception as e:
+            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+
+        result = None
+        cm = getattr(local_pipe, "model_cpu_offload_context", None)
+        try:
+            if callable(cm):
+                try:
+                    with cm():
+                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+                except TypeError:
+                    # cm might be a context manager instance rather than callable
+                    try:
+                        with cm:
+                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+                    except Exception as e:
+                        logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
+                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+            else:
+                # no offload context available — call directly
+                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

+            return result
+
+        finally:
+            try:
+                for name, tok in tokenizer_wrappers.items():
+                    if name.startswith("components["):
+                        key = name[len("components["):-1]
+                        local_pipe.components[key] = tok
+                    else:
+                        setattr(local_pipe, name, tok)
+            except Exception as e:
+                logger.debug(f"Error restoring wrapped tokenizers: {e}")
