diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py
index 57d1e2567169..9c3276c31c69 100644
--- a/examples/server-async/utils/requestscopedpipeline.py
+++ b/examples/server-async/utils/requestscopedpipeline.py
@@ -7,16 +7,12 @@
 from diffusers.utils import logging
 
 from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
 
 
 logger = logging.get_logger(__name__)
 
 
-def safe_tokenize(tokenizer, *args, lock, **kwargs):
-    with lock:
-        return tokenizer(*args, **kwargs)
-
-
 class RequestScopedPipeline:
     DEFAULT_MUTABLE_ATTRS = [
         "_all_hooks",
@@ -38,23 +34,40 @@ def __init__(
         wrap_scheduler: bool = True,
     ):
         self._base = pipeline
 
+        self.unet = getattr(pipeline, "unet", None)
         self.vae = getattr(pipeline, "vae", None)
         self.text_encoder = getattr(pipeline, "text_encoder", None)
         self.components = getattr(pipeline, "components", None)
+        self.transformer = getattr(pipeline, "transformer", None)
+
         if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
             if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                 pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
 
         self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+
         self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
+        self._vae_lock = threading.Lock()
+        self._image_lock = threading.Lock()
+
         self._auto_detect_mutables = bool(auto_detect_mutables)
         self._tensor_numel_threshold = int(tensor_numel_threshold)
-
         self._auto_detected_attrs: List[str] = []
 
+    def _detect_kernel_pipeline(self, pipeline) -> bool:
+        kernel_indicators = [
+            "text_encoding_cache",
+            "memory_manager",
+            "enable_optimizations",
+            "_create_request_context",
+            "get_optimization_stats",
+        ]
+
+        return any(hasattr(pipeline, attr) for attr in kernel_indicators)
+
     def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
         base_sched = getattr(self._base, "scheduler", None)
         if base_sched is None:
@@ -70,11 +83,21 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
                 num_inference_steps=num_inference_steps, device=device, **clone_kwargs
             )
         except Exception as e:
-            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+            logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
             try:
-                return copy.deepcopy(wrapped_scheduler)
-            except Exception as e:
-                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+                if hasattr(wrapped_scheduler, "scheduler"):
+                    try:
+                        copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
+                        return BaseAsyncScheduler(copied_scheduler)
+                    except Exception:
+                        return wrapped_scheduler
+                else:
+                    copied_scheduler = copy.copy(wrapped_scheduler)
+                    return BaseAsyncScheduler(copied_scheduler)
+            except Exception as e2:
+                logger.warning(
+                    f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
+                )
                 return wrapped_scheduler
 
     def _autodetect_mutables(self, max_attrs: int = 40):
@@ -86,6 +109,7 @@ def _autodetect_mutables(self, max_attrs: int = 40):
         candidates: List[str] = []
         seen = set()
+
         for name in dir(self._base):
             if name.startswith("__"):
                 continue
@@ -93,6 +117,7 @@ def _autodetect_mutables(self, max_attrs: int = 40):
                 continue
             if name in ("to", "save_pretrained", "from_pretrained"):
                 continue
+
             try:
                 val = getattr(self._base, name)
             except Exception:
                 continue
@@ -100,11 +125,9 @@ def _autodetect_mutables(self, max_attrs: int = 40):
             import types
 
-            # skip callables and modules
             if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                 continue
 
-            # containers -> candidate
             if isinstance(val, (dict, list, set, tuple, bytearray)):
                 candidates.append(name)
                 seen.add(name)
 
@@ -205,6 +228,9 @@ def _is_tokenizer_component(self, component) -> bool:
 
         return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
 
+    def _should_wrap_tokenizers(self) -> bool:
+        return True
+
     def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
         local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
 
@@ -214,6 +240,25 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
             logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
             local_pipe = copy.deepcopy(self._base)
 
+        try:
+            if (
+                hasattr(local_pipe, "vae")
+                and local_pipe.vae is not None
+                and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
+            ):
+                local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
+
+            if (
+                hasattr(local_pipe, "image_processor")
+                and local_pipe.image_processor is not None
+                and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
+            ):
+                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
+                    local_pipe.image_processor, self._image_lock
+                )
+        except Exception as e:
+            logger.debug(f"Could not wrap vae/image_processor: {e}")
+
         if local_scheduler is not None:
             try:
                 timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
@@ -231,47 +276,42 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
         self._clone_mutable_attrs(self._base, local_pipe)
 
-        # 4) wrap tokenizers on the local pipe with the lock wrapper
-        tokenizer_wrappers = {}  # name -> original_tokenizer
-        try:
-            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
-            for name in dir(local_pipe):
-                if "tokenizer" in name and not name.startswith("_"):
-                    tok = getattr(local_pipe, name, None)
-                    if tok is not None and self._is_tokenizer_component(tok):
-                        tokenizer_wrappers[name] = tok
-                        setattr(
-                            local_pipe,
-                            name,
-                            lambda *args, tok=tok, **kwargs: safe_tokenize(
-                                tok, *args, lock=self._tokenizer_lock, **kwargs
-                            ),
-                        )
-
-            # b) wrap tokenizers in components dict
-            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
-                for key, val in local_pipe.components.items():
-                    if val is None:
-                        continue
-
-                    if self._is_tokenizer_component(val):
-                        tokenizer_wrappers[f"components[{key}]"] = val
-                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
-                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
-                        )
+        original_tokenizers = {}
 
-        except Exception as e:
-            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+        if self._should_wrap_tokenizers():
+            try:
+                for name in dir(local_pipe):
+                    if "tokenizer" in name and not name.startswith("_"):
+                        tok = getattr(local_pipe, name, None)
+                        if tok is not None and self._is_tokenizer_component(tok):
+                            if not isinstance(tok, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[name] = tok
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
+                                setattr(local_pipe, name, wrapped_tokenizer)
+
+                if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                    for key, val in local_pipe.components.items():
+                        if val is None:
+                            continue
+
+                        if self._is_tokenizer_component(val):
+                            if not isinstance(val, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[f"components[{key}]"] = val
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
+                                local_pipe.components[key] = wrapped_tokenizer
+
+            except Exception as e:
+                logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
 
         result = None
         cm = getattr(local_pipe, "model_cpu_offload_context", None)
+
         try:
             if callable(cm):
                 try:
                     with cm():
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                 except TypeError:
-                    # cm might be a context manager instance rather than callable
                     try:
                         with cm:
                             result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
@@ -279,18 +319,18 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
                     except Exception as e:
                         logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
             else:
-                # no offload context available — call directly
                 result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
 
             return result
         finally:
             try:
-                for name, tok in tokenizer_wrappers.items():
+                for name, tok in original_tokenizers.items():
                     if name.startswith("components["):
                         key = name[len("components[") : -1]
-                        local_pipe.components[key] = tok
+                        if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                            local_pipe.components[key] = tok
                     else:
                         setattr(local_pipe, name, tok)
             except Exception as e:
-                logger.debug(f"Error restoring wrapped tokenizers: {e}")
+                logger.debug(f"Error restoring original tokenizers: {e}")
 
diff --git a/examples/server-async/utils/wrappers.py b/examples/server-async/utils/wrappers.py
new file mode 100644
index 000000000000..1e8474eabf3f
--- /dev/null
+++ b/examples/server-async/utils/wrappers.py
@@ -0,0 +1,86 @@
+class ThreadSafeTokenizerWrapper:
+    def __init__(self, tokenizer, lock):
+        self._tokenizer = tokenizer
+        self._lock = lock
+
+        self._thread_safe_methods = {
+            "__call__",
+            "encode",
+            "decode",
+            "tokenize",
+            "encode_plus",
+            "batch_encode_plus",
+            "batch_decode",
+        }
+
+    def __getattr__(self, name):
+        attr = getattr(self._tokenizer, name)
+
+        if name in self._thread_safe_methods and callable(attr):
+
+            def wrapped_method(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            return wrapped_method
+
+        return attr
+
+    def __call__(self, *args, **kwargs):
+        with self._lock:
+            return self._tokenizer(*args, **kwargs)
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._tokenizer, name, value)
+
+    def __dir__(self):
+        return dir(self._tokenizer)
+
+
+class ThreadSafeVAEWrapper:
+    def __init__(self, vae, lock):
+        self._vae = vae
+        self._lock = lock
+
+    def __getattr__(self, name):
+        attr = getattr(self._vae, name)
+        if name in {"decode", "encode", "forward"} and callable(attr):
+
+            def wrapped(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            return wrapped
+        return attr
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._vae, name, value)
+
+
+class ThreadSafeImageProcessorWrapper:
+    def __init__(self, proc, lock):
+        self._proc = proc
+        self._lock = lock
+
+    def __getattr__(self, name):
+        attr = getattr(self._proc, name)
+        if name in {"postprocess", "preprocess"} and callable(attr):
+
+            def wrapped(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            return wrapped
+        return attr
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._proc, name, value)
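
Reviewer note: a minimal sketch of how `RequestScopedPipeline` and the new wrappers are meant to compose at serving time. This is illustrative only, not part of the diff; the checkpoint name and the two-thread driver are assumptions, and it presumes the `examples/server-async/utils` package is importable as `utils`.

```python
# Hypothetical usage sketch (not part of this PR). One shared base pipeline
# serves concurrent requests: each generate() call builds a request-local
# shallow copy with its own scheduler clone and lock-guarded
# tokenizer/vae/image_processor wrappers, while the heavy weights stay shared.
import threading

import torch
from diffusers import StableDiffusion3Pipeline

from utils.requestscopedpipeline import RequestScopedPipeline

# Checkpoint name is illustrative; any diffusers pipeline with a scheduler works.
base = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

request_pipe = RequestScopedPipeline(base)


def handle_request(prompt: str, path: str):
    # Mutable per-request state is cloned inside generate(); callers just
    # invoke it concurrently as they would the underlying pipeline.
    result = request_pipe.generate(prompt=prompt, num_inference_steps=30)
    result.images[0].save(path)


threads = [
    threading.Thread(target=handle_request, args=(p, f"out_{i}.png"))
    for i, p in enumerate(["a watercolor fox", "a neon city at night"])
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```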
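
The three wrappers share one delegation pattern: only the hot methods named in an allow-set acquire the lock, while every other attribute read or write passes straight through via `__getattr__`/`__setattr__`, so the wrapped object remains a drop-in replacement. A self-contained illustration, with a hypothetical `DummyTokenizer` standing in for a real tokenizer:

```python
# Minimal demonstration of the lock-on-hot-methods delegation pattern used by
# the wrappers above. DummyTokenizer is a made-up stand-in, not a real tokenizer.
import threading

from utils.wrappers import ThreadSafeTokenizerWrapper


class DummyTokenizer:
    vocab_size = 100  # plain attribute, read without taking the lock

    def __call__(self, text):
        return [ord(c) % self.vocab_size for c in text]


tok = ThreadSafeTokenizerWrapper(DummyTokenizer(), threading.Lock())

print(tok.vocab_size)  # delegated via __getattr__, lock never acquired
print(tok("abc"))      # __call__ is in the allow-set, so it runs under the lock
```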