diff --git a/examples/server-async/Pipelines.py b/examples/server-async/Pipelines.py new file mode 100644 index 000000000000..f89cac6a7e4b --- /dev/null +++ b/examples/server-async/Pipelines.py @@ -0,0 +1,91 @@ +import logging +import os +from dataclasses import dataclass, field +from typing import List + +import torch +from pydantic import BaseModel + +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline + + +logger = logging.getLogger(__name__) + + +class TextToImageInput(BaseModel): + model: str + prompt: str + size: str | None = None + n: int | None = None + + +@dataclass +class PresetModels: + SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"]) + SD3_5: List[str] = field( + default_factory=lambda: [ + "stabilityai/stable-diffusion-3.5-large", + "stabilityai/stable-diffusion-3.5-large-turbo", + "stabilityai/stable-diffusion-3.5-medium", + ] + ) + + +class TextToImagePipelineSD3: + def __init__(self, model_path: str | None = None): + self.model_path = model_path or os.getenv("MODEL_PATH") + self.pipeline: StableDiffusion3Pipeline | None = None + self.device: str | None = None + + def start(self): + if torch.cuda.is_available(): + model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large" + logger.info("Loading CUDA") + self.device = "cuda" + self.pipeline = StableDiffusion3Pipeline.from_pretrained( + model_path, + torch_dtype=torch.float16, + ).to(device=self.device) + elif torch.backends.mps.is_available(): + model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium" + logger.info("Loading MPS for Mac M Series") + self.device = "mps" + self.pipeline = StableDiffusion3Pipeline.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + ).to(device=self.device) + else: + raise Exception("No CUDA or MPS device available") + + +class ModelPipelineInitializer: + def __init__(self, model: str = "", type_models: str = "t2im"): + self.model = model + self.type_models = type_models + self.pipeline = None + self.device = "cuda" if torch.cuda.is_available() else "mps" + self.model_type = None + + def initialize_pipeline(self): + if not self.model: + raise ValueError("Model name not provided") + + # Check if model exists in PresetModels + preset_models = PresetModels() + + # Determine which model type we're dealing with + if self.model in preset_models.SD3: + self.model_type = "SD3" + elif self.model in preset_models.SD3_5: + self.model_type = "SD3_5" + + # Create appropriate pipeline based on model type and type_models + if self.type_models == "t2im": + if self.model_type in ["SD3", "SD3_5"]: + self.pipeline = TextToImagePipelineSD3(self.model) + else: + raise ValueError(f"Model type {self.model_type} not supported for text-to-image") + elif self.type_models == "t2v": + raise ValueError(f"Unsupported type_models: {self.type_models}") + + return self.pipeline diff --git a/examples/server-async/README.md b/examples/server-async/README.md new file mode 100644 index 000000000000..a47ab7c7f224 --- /dev/null +++ b/examples/server-async/README.md @@ -0,0 +1,171 @@ +# Asynchronous server and parallel execution of models + +> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines. 
+> We recommend running 10 to 50 inferences in parallel for optimal performance; average request times then range from roughly 25-30 seconds up to about 1 minute 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two parallel inferences to avoid decoding or saving errors caused by running out of memory.)
+
+## ⚠️ IMPORTANT
+
+* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
+
+## Necessary components
+
+All the components needed to create the inference server are in the current directory:
+
+```
+server-async/
+├── utils/
+├─────── __init__.py
+├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for safe concurrent inference
+├─────── requestscopedpipeline.py # RequestScopedPipeline for inference with a single in-memory model
+├─────── utils.py # Image/video saving utilities and service configuration
+├── Pipelines.py # pipeline loader classes (SD3)
+├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
+├── test.py # Client test script for inference requests
+├── requirements.txt # Dependencies
+└── README.md # This documentation
+```
+
+## What `diffusers-async` adds / Why we needed it
+
+Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or exhaust memory by deep-copying the whole pipeline per request.
+
+`diffusers-async` / this example addresses that by:
+
+* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
+* **Per-request mutable state**: small stateful objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, falling back to a safe `deepcopy` or other heuristics.
+* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
+* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not passing `return_scheduler=True`, the behavior is identical to the original API.
+* **Robust attribute handling**: the wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. A configurable tensor-size threshold prevents large tensors from being cloned.
+* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
+
+## How the server works (high-level flow)
+
+1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
+2. 
On each HTTP inference request: + + * The server uses `RequestScopedPipeline.generate(...)` which: + + * automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped), + * obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`), + * does `local_pipe = copy.copy(base_pipe)` (shallow copy), + * sets `local_pipe.scheduler = local_scheduler` (if possible), + * clones only small mutable attributes (callbacks, rng, small latents) with auto-detection, + * wraps tokenizers with thread-safe locks to prevent race conditions, + * optionally enters a `model_cpu_offload_context()` for memory offload hooks, + * calls the pipeline on the local view (`local_pipe(...)`). +3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`). +4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state. + +## How to set up and run the server + +### 1) Install dependencies + +Recommended: create a virtualenv / conda environment. + +```bash +pip install diffusers +pip install -r requirements.txt +``` + +### 2) Start the server + +Using the `serverasync.py` file that already has everything you need: + +```bash +python serverasync.py +``` + +The server will start on `http://localhost:8500` by default with the following features: +- FastAPI application with async lifespan management +- Automatic model loading and pipeline initialization +- Request counting and active inference tracking +- Memory cleanup after each inference +- CORS middleware for cross-origin requests + +### 3) Test the server + +Use the included test script: + +```bash +python test.py +``` + +Or send a manual request: + +`POST /api/diffusers/inference` with JSON body: + +```json +{ + "prompt": "A futuristic cityscape, vibrant colors", + "num_inference_steps": 30, + "num_images_per_prompt": 1 +} +``` + +Response example: + +```json +{ + "response": ["http://localhost:8500/images/img123.png"] +} +``` + +### 4) Server endpoints + +- `GET /` - Welcome message +- `POST /api/diffusers/inference` - Main inference endpoint +- `GET /images/{filename}` - Serve generated images +- `GET /api/status` - Server status and memory info + +## Advanced Configuration + +### RequestScopedPipeline Parameters + +```python +RequestScopedPipeline( + pipeline, # Base pipeline to wrap + mutable_attrs=None, # Custom list of attributes to clone + auto_detect_mutables=True, # Enable automatic detection of mutable attributes + tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning + tokenizer_lock=None, # Custom threading lock for tokenizers + wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler +) +``` + +### BaseAsyncScheduler Features + +* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__` +* `clone_for_request()` method for safe per-request scheduler cloning +* Enhanced debugging with `__repr__` and `__str__` methods +* Full compatibility with existing scheduler APIs + +### Server Configuration + +The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass: + +```python +@dataclass +class ServerConfigModels: + model: str = 'stabilityai/stable-diffusion-3.5-medium' + type_models: str = 't2im' + host: str = '0.0.0.0' + port: int = 8500 +``` + +## Troubleshooting (quick) + +* `Already borrowed` — previously a Rust tokenizer concurrency error. 
+ ✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen. + +* `can't set attribute 'components'` — pipeline exposes read-only `components`. + ✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically. + +* Scheduler issues: + * If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler. + ✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged. + +* Memory issues with large tensors: + ✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones. + +* Automatic tokenizer detection: + ✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers. \ No newline at end of file diff --git a/examples/server-async/requirements.txt b/examples/server-async/requirements.txt new file mode 100644 index 000000000000..aafa93b7023f --- /dev/null +++ b/examples/server-async/requirements.txt @@ -0,0 +1,10 @@ +torch +torchvision +transformers +sentencepiece +fastapi +uvicorn +ftfy +accelerate +xformers +protobuf \ No newline at end of file diff --git a/examples/server-async/serverasync.py b/examples/server-async/serverasync.py new file mode 100644 index 000000000000..b279b36f9a84 --- /dev/null +++ b/examples/server-async/serverasync.py @@ -0,0 +1,230 @@ +import asyncio +import gc +import logging +import os +import random +import threading +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, Dict, Optional, Type + +import torch +from fastapi import FastAPI, HTTPException, Request +from fastapi.concurrency import run_in_threadpool +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from Pipelines import ModelPipelineInitializer +from pydantic import BaseModel + +from utils import RequestScopedPipeline, Utils + + +@dataclass +class ServerConfigModels: + model: str = "stabilityai/stable-diffusion-3.5-medium" + type_models: str = "t2im" + constructor_pipeline: Optional[Type] = None + custom_pipeline: Optional[Type] = None + components: Optional[Dict[str, Any]] = None + torch_dtype: Optional[torch.dtype] = None + host: str = "0.0.0.0" + port: int = 8500 + + +server_config = ServerConfigModels() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logging.basicConfig(level=logging.INFO) + app.state.logger = logging.getLogger("diffusers-server") + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + + app.state.total_requests = 0 + app.state.active_inferences = 0 + app.state.metrics_lock = asyncio.Lock() + app.state.metrics_task = None + + app.state.utils_app = Utils( + host=server_config.host, + port=server_config.port, + ) + + async def metrics_loop(): + try: + while True: + async with app.state.metrics_lock: + total = app.state.total_requests + active = app.state.active_inferences + app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}") + await asyncio.sleep(5) + except asyncio.CancelledError: + app.state.logger.info("Metrics loop cancelled") + raise + + app.state.metrics_task = 
asyncio.create_task(metrics_loop()) + + try: + yield + finally: + task = app.state.metrics_task + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + try: + stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None) + if callable(stop_fn): + await run_in_threadpool(stop_fn) + except Exception as e: + app.state.logger.warning(f"Error during pipeline shutdown: {e}") + + app.state.logger.info("Lifespan shutdown complete") + + +app = FastAPI(lifespan=lifespan) + +logger = logging.getLogger("DiffusersServer.Pipelines") + + +initializer = ModelPipelineInitializer( + model=server_config.model, + type_models=server_config.type_models, +) +model_pipeline = initializer.initialize_pipeline() +model_pipeline.start() + +request_pipe = RequestScopedPipeline(model_pipeline.pipeline) +pipeline_lock = threading.Lock() + +logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})") + +app.state.MODEL_INITIALIZER = initializer +app.state.MODEL_PIPELINE = model_pipeline +app.state.REQUEST_PIPE = request_pipe +app.state.PIPELINE_LOCK = pipeline_lock + + +class JSONBodyQueryAPI(BaseModel): + model: str | None = None + prompt: str + negative_prompt: str | None = None + num_inference_steps: int = 28 + num_images_per_prompt: int = 1 + + +@app.middleware("http") +async def count_requests_middleware(request: Request, call_next): + async with app.state.metrics_lock: + app.state.total_requests += 1 + response = await call_next(request) + return response + + +@app.get("/") +async def root(): + return {"message": "Welcome to the Diffusers Server"} + + +@app.post("/api/diffusers/inference") +async def api(json: JSONBodyQueryAPI): + prompt = json.prompt + negative_prompt = json.negative_prompt or "" + num_steps = json.num_inference_steps + num_images_per_prompt = json.num_images_per_prompt + + wrapper = app.state.MODEL_PIPELINE + initializer = app.state.MODEL_INITIALIZER + + utils_app = app.state.utils_app + + if not wrapper or not wrapper.pipeline: + raise HTTPException(500, "Model not initialized correctly") + if not prompt.strip(): + raise HTTPException(400, "No prompt provided") + + def make_generator(): + g = torch.Generator(device=initializer.device) + return g.manual_seed(random.randint(0, 10_000_000)) + + req_pipe = app.state.REQUEST_PIPE + + def infer(): + gen = make_generator() + return req_pipe.generate( + prompt=prompt, + negative_prompt=negative_prompt, + generator=gen, + num_inference_steps=num_steps, + num_images_per_prompt=num_images_per_prompt, + device=initializer.device, + output_type="pil", + ) + + try: + async with app.state.metrics_lock: + app.state.active_inferences += 1 + + output = await run_in_threadpool(infer) + + async with app.state.metrics_lock: + app.state.active_inferences = max(0, app.state.active_inferences - 1) + + urls = [utils_app.save_image(img) for img in output.images] + return {"response": urls} + + except Exception as e: + async with app.state.metrics_lock: + app.state.active_inferences = max(0, app.state.active_inferences - 1) + logger.error(f"Error during inference: {e}") + raise HTTPException(500, f"Error in processing: {e}") + + finally: + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.ipc_collect() + gc.collect() + + +@app.get("/images/{filename}") +async def serve_image(filename: str): + utils_app = app.state.utils_app + file_path = os.path.join(utils_app.image_dir, filename) + if not 
os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="Image not found") + return FileResponse(file_path, media_type="image/png") + + +@app.get("/api/status") +async def get_status(): + memory_info = {} + if torch.cuda.is_available(): + memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB + memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB + memory_info = { + "memory_allocated_gb": round(memory_allocated, 2), + "memory_reserved_gb": round(memory_reserved, 2), + "device": torch.cuda.get_device_name(0), + } + + return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info} + + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host=server_config.host, port=server_config.port) diff --git a/examples/server-async/test.py b/examples/server-async/test.py new file mode 100644 index 000000000000..e67317ea8f6b --- /dev/null +++ b/examples/server-async/test.py @@ -0,0 +1,65 @@ +import os +import time +import urllib.parse + +import requests + + +SERVER_URL = "http://localhost:8500/api/diffusers/inference" +BASE_URL = "http://localhost:8500" +DOWNLOAD_FOLDER = "generated_images" +WAIT_BEFORE_DOWNLOAD = 2 # seconds + +os.makedirs(DOWNLOAD_FOLDER, exist_ok=True) + + +def save_from_url(url: str) -> str: + """Download the given URL (relative or absolute) and save it locally.""" + if url.startswith("/"): + direct = BASE_URL.rstrip("/") + url + else: + direct = url + resp = requests.get(direct, timeout=60) + resp.raise_for_status() + filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png" + path = os.path.join(DOWNLOAD_FOLDER, filename) + with open(path, "wb") as f: + f.write(resp.content) + return path + + +def main(): + payload = { + "prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style", + "num_inference_steps": 30, + "num_images_per_prompt": 1, + } + + print("Sending request...") + try: + r = requests.post(SERVER_URL, json=payload, timeout=480) + r.raise_for_status() + except Exception as e: + print(f"Request failed: {e}") + return + + body = r.json().get("response", []) + # Normalize to a list + urls = body if isinstance(body, list) else [body] if body else [] + if not urls: + print("No URLs found in the response. Check the server output.") + return + + print(f"Received {len(urls)} URL(s). 
Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...") + time.sleep(WAIT_BEFORE_DOWNLOAD) + + for u in urls: + try: + path = save_from_url(u) + print(f"Image saved to: {path}") + except Exception as e: + print(f"Error downloading {u}: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/server-async/utils/__init__.py b/examples/server-async/utils/__init__.py new file mode 100644 index 000000000000..731cfe491ae5 --- /dev/null +++ b/examples/server-async/utils/__init__.py @@ -0,0 +1,2 @@ +from .requestscopedpipeline import RequestScopedPipeline +from .utils import Utils diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py new file mode 100644 index 000000000000..57d1e2567169 --- /dev/null +++ b/examples/server-async/utils/requestscopedpipeline.py @@ -0,0 +1,296 @@ +import copy +import threading +from typing import Any, Iterable, List, Optional + +import torch + +from diffusers.utils import logging + +from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps + + +logger = logging.get_logger(__name__) + + +def safe_tokenize(tokenizer, *args, lock, **kwargs): + with lock: + return tokenizer(*args, **kwargs) + + +class RequestScopedPipeline: + DEFAULT_MUTABLE_ATTRS = [ + "_all_hooks", + "_offload_device", + "_progress_bar_config", + "_progress_bar", + "_rng_state", + "_last_seed", + "latents", + ] + + def __init__( + self, + pipeline: Any, + mutable_attrs: Optional[Iterable[str]] = None, + auto_detect_mutables: bool = True, + tensor_numel_threshold: int = 1_000_000, + tokenizer_lock: Optional[threading.Lock] = None, + wrap_scheduler: bool = True, + ): + self._base = pipeline + self.unet = getattr(pipeline, "unet", None) + self.vae = getattr(pipeline, "vae", None) + self.text_encoder = getattr(pipeline, "text_encoder", None) + self.components = getattr(pipeline, "components", None) + + if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None: + if not isinstance(pipeline.scheduler, BaseAsyncScheduler): + pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler) + + self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS) + self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock() + + self._auto_detect_mutables = bool(auto_detect_mutables) + self._tensor_numel_threshold = int(tensor_numel_threshold) + + self._auto_detected_attrs: List[str] = [] + + def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs): + base_sched = getattr(self._base, "scheduler", None) + if base_sched is None: + return None + + if not isinstance(base_sched, BaseAsyncScheduler): + wrapped_scheduler = BaseAsyncScheduler(base_sched) + else: + wrapped_scheduler = base_sched + + try: + return wrapped_scheduler.clone_for_request( + num_inference_steps=num_inference_steps, device=device, **clone_kwargs + ) + except Exception as e: + logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()") + try: + return copy.deepcopy(wrapped_scheduler) + except Exception as e: + logger.warning(f"Deepcopy of scheduler failed: {e}. 
Returning original scheduler (*risky*).") + return wrapped_scheduler + + def _autodetect_mutables(self, max_attrs: int = 40): + if not self._auto_detect_mutables: + return [] + + if self._auto_detected_attrs: + return self._auto_detected_attrs + + candidates: List[str] = [] + seen = set() + for name in dir(self._base): + if name.startswith("__"): + continue + if name in self._mutable_attrs: + continue + if name in ("to", "save_pretrained", "from_pretrained"): + continue + try: + val = getattr(self._base, name) + except Exception: + continue + + import types + + # skip callables and modules + if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)): + continue + + # containers -> candidate + if isinstance(val, (dict, list, set, tuple, bytearray)): + candidates.append(name) + seen.add(name) + else: + # try Tensor detection + try: + if isinstance(val, torch.Tensor): + if val.numel() <= self._tensor_numel_threshold: + candidates.append(name) + seen.add(name) + else: + logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}") + except Exception: + continue + + if len(candidates) >= max_attrs: + break + + self._auto_detected_attrs = candidates + logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}") + return self._auto_detected_attrs + + def _is_readonly_property(self, base_obj, attr_name: str) -> bool: + try: + cls = type(base_obj) + descriptor = getattr(cls, attr_name, None) + if isinstance(descriptor, property): + return descriptor.fset is None + if hasattr(descriptor, "__set__") is False and descriptor is not None: + return False + except Exception: + pass + return False + + def _clone_mutable_attrs(self, base, local): + attrs_to_clone = list(self._mutable_attrs) + attrs_to_clone.extend(self._autodetect_mutables()) + + EXCLUDE_ATTRS = { + "components", + } + + for attr in attrs_to_clone: + if attr in EXCLUDE_ATTRS: + logger.debug(f"Skipping excluded attr '{attr}'") + continue + if not hasattr(base, attr): + continue + if self._is_readonly_property(base, attr): + logger.debug(f"Skipping read-only property '{attr}'") + continue + + try: + val = getattr(base, attr) + except Exception as e: + logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}") + continue + + try: + if isinstance(val, dict): + setattr(local, attr, dict(val)) + elif isinstance(val, (list, tuple, set)): + setattr(local, attr, list(val)) + elif isinstance(val, bytearray): + setattr(local, attr, bytearray(val)) + else: + # small tensors or atomic values + if isinstance(val, torch.Tensor): + if val.numel() <= self._tensor_numel_threshold: + setattr(local, attr, val.clone()) + else: + # don't clone big tensors, keep reference + setattr(local, attr, val) + else: + try: + setattr(local, attr, copy.copy(val)) + except Exception: + setattr(local, attr, val) + except (AttributeError, TypeError) as e: + logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}") + continue + except Exception as e: + logger.debug(f"Unexpected error cloning attribute '{attr}': {e}") + continue + + def _is_tokenizer_component(self, component) -> bool: + if component is None: + return False + + tokenizer_methods = ["encode", "decode", "tokenize", "__call__"] + has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods) + + class_name = component.__class__.__name__.lower() + has_tokenizer_in_name = "tokenizer" in class_name + + tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"] + 
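+        # Heuristic: treat the component as a tokenizer only if it exposes tokenizer-style
+        # methods AND either has "tokenizer" in its class name or common tokenizer attributes.
+        # This avoids wrapping components (e.g. text encoders) that merely share a `__call__` method.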
has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs) + + return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs) + + def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs): + local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device) + + try: + local_pipe = copy.copy(self._base) + except Exception as e: + logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).") + local_pipe = copy.deepcopy(self._base) + + if local_scheduler is not None: + try: + timesteps, num_steps, configured_scheduler = async_retrieve_timesteps( + local_scheduler.scheduler, + num_inference_steps=num_inference_steps, + device=device, + return_scheduler=True, + **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]}, + ) + + final_scheduler = BaseAsyncScheduler(configured_scheduler) + setattr(local_pipe, "scheduler", final_scheduler) + except Exception: + logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.") + + self._clone_mutable_attrs(self._base, local_pipe) + + # 4) wrap tokenizers on the local pipe with the lock wrapper + tokenizer_wrappers = {} # name -> original_tokenizer + try: + # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...) + for name in dir(local_pipe): + if "tokenizer" in name and not name.startswith("_"): + tok = getattr(local_pipe, name, None) + if tok is not None and self._is_tokenizer_component(tok): + tokenizer_wrappers[name] = tok + setattr( + local_pipe, + name, + lambda *args, tok=tok, **kwargs: safe_tokenize( + tok, *args, lock=self._tokenizer_lock, **kwargs + ), + ) + + # b) wrap tokenizers in components dict + if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict): + for key, val in local_pipe.components.items(): + if val is None: + continue + + if self._is_tokenizer_component(val): + tokenizer_wrappers[f"components[{key}]"] = val + local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize( + tokenizer, *args, lock=self._tokenizer_lock, **kwargs + ) + + except Exception as e: + logger.debug(f"Tokenizer wrapping step encountered an error: {e}") + + result = None + cm = getattr(local_pipe, "model_cpu_offload_context", None) + try: + if callable(cm): + try: + with cm(): + result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) + except TypeError: + # cm might be a context manager instance rather than callable + try: + with cm: + result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) + except Exception as e: + logger.debug(f"model_cpu_offload_context usage failed: {e}. 
Proceeding without it.") + result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) + else: + # no offload context available — call directly + result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) + + return result + + finally: + try: + for name, tok in tokenizer_wrappers.items(): + if name.startswith("components["): + key = name[len("components[") : -1] + local_pipe.components[key] = tok + else: + setattr(local_pipe, name, tok) + except Exception as e: + logger.debug(f"Error restoring wrapped tokenizers: {e}") diff --git a/examples/server-async/utils/scheduler.py b/examples/server-async/utils/scheduler.py new file mode 100644 index 000000000000..86d47cac6154 --- /dev/null +++ b/examples/server-async/utils/scheduler.py @@ -0,0 +1,141 @@ +import copy +import inspect +from typing import Any, List, Optional, Union + +import torch + + +class BaseAsyncScheduler: + def __init__(self, scheduler: Any): + self.scheduler = scheduler + + def __getattr__(self, name: str): + if hasattr(self.scheduler, name): + return getattr(self.scheduler, name) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value): + if name == "scheduler": + super().__setattr__(name, value) + else: + if hasattr(self, "scheduler") and hasattr(self.scheduler, name): + setattr(self.scheduler, name, value) + else: + super().__setattr__(name, value) + + def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs): + local = copy.deepcopy(self.scheduler) + local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs) + cloned = self.__class__(local) + return cloned + + def __repr__(self): + return f"BaseAsyncScheduler({repr(self.scheduler)})" + + def __str__(self): + return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}" + + +def async_retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. + Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Backwards compatible: by default the function behaves exactly as before and returns + (timesteps_tensor, num_inference_steps) + + If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed + scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`) + or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return: + (timesteps_tensor, num_inference_steps, scheduler_in_use) + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Optional kwargs: + return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use) + where `scheduler_in_use` is a scheduler instance that already has timesteps set. + This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler. + + Returns: + `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or + `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`. + """ + # pop our optional control kwarg (keeps compatibility) + return_scheduler = bool(kwargs.pop("return_scheduler", False)) + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + + # choose scheduler to call set_timesteps on + scheduler_in_use = scheduler + if return_scheduler: + # Do not mutate the provided scheduler: prefer to clone if possible + if hasattr(scheduler, "clone_for_request"): + try: + # clone_for_request may accept num_inference_steps or other kwargs; be permissive + scheduler_in_use = scheduler.clone_for_request( + num_inference_steps=num_inference_steps or 0, device=device + ) + except Exception: + scheduler_in_use = copy.deepcopy(scheduler) + else: + # fallback deepcopy (scheduler tends to be smallish - acceptable) + scheduler_in_use = copy.deepcopy(scheduler) + + # helper to test if set_timesteps supports a particular kwarg + def _accepts(param_name: str) -> bool: + try: + return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys()) + except (ValueError, TypeError): + # if signature introspection fails, be permissive and attempt the call later + return False + + # now call set_timesteps on the chosen scheduler_in_use (may be original or clone) + if timesteps is not None: + accepts_timesteps = _accepts("timesteps") + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps_out = scheduler_in_use.timesteps + num_inference_steps = len(timesteps_out) + elif sigmas is not None: + accept_sigmas = _accepts("sigmas") + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps_out = scheduler_in_use.timesteps + num_inference_steps = len(timesteps_out) + else: + # default path + scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps_out = scheduler_in_use.timesteps + + if return_scheduler: + return timesteps_out, num_inference_steps, scheduler_in_use + return timesteps_out, num_inference_steps diff --git a/examples/server-async/utils/utils.py b/examples/server-async/utils/utils.py new file mode 100644 index 000000000000..9f943305126c --- /dev/null +++ b/examples/server-async/utils/utils.py @@ -0,0 +1,48 @@ +import gc +import logging +import os +import tempfile +import uuid + +import torch + + +logger = logging.getLogger(__name__) + + +class Utils: + def __init__(self, host: str = "0.0.0.0", port: int = 8500): + self.service_url = f"http://{host}:{port}" + self.image_dir = os.path.join(tempfile.gettempdir(), "images") + if not os.path.exists(self.image_dir): + os.makedirs(self.image_dir) + + self.video_dir = os.path.join(tempfile.gettempdir(), "videos") + if not os.path.exists(self.video_dir): + os.makedirs(self.video_dir) + + def save_image(self, image): + if hasattr(image, "to"): + try: + image = image.to("cpu") + except Exception: + pass + + if isinstance(image, torch.Tensor): + from torchvision import transforms + + to_pil = transforms.ToPILImage() + image = to_pil(image.squeeze(0).clamp(0, 1)) + + filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png" + image_path = os.path.join(self.image_dir, filename) + logger.info(f"Saving image to {image_path}") + + image.save(image_path, format="PNG", optimize=True) + + del image + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return os.path.join(self.service_url, "images", filename)
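+
+# Usage note (illustrative): serverasync.py constructs `Utils(host, port)` once at startup and
+# calls `save_image(...)` on each generated PIL image, returning URLs such as
+# "http://0.0.0.0:8500/images/img1a2b3c4d.png" to the client.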