
Commit 0beab1c

Update examples/server-async
1 parent e676b34 commit 0beab1c

6 files changed: +123 -114 lines changed


examples/server-async/DiffusersServer/Pipelines.py

Lines changed: 1 addition & 24 deletions
@@ -1,4 +1,4 @@
-# from https://github.com/F4k3r22/DiffusersServer/blob/main/DiffusersServer/Pipelines.py
+# Pipelines.py
 
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
@@ -18,22 +18,12 @@ class TextToImageInput(BaseModel):
 
 class TextToImagePipelineSD3:
     def __init__(self, model_path: str | None = None):
-        """
-        Inicialización de la clase con la ruta del modelo.
-        Si no se proporciona, se obtiene de la variable de entorno.
-        """
         self.model_path = model_path or os.getenv("MODEL_PATH")
         self.pipeline: StableDiffusion3Pipeline = None
         self.device: str = None
 
     def start(self):
-        """
-        Inicia el pipeline cargando el modelo en CUDA o MPS según esté disponible.
-        Se utiliza la ruta del modelo definida en el __init__ y se asigna un valor predeterminado
-        en función del dispositivo disponible si no se definió previamente.
-        """
         if torch.cuda.is_available():
-            # Si no se definió model_path, se asigna el valor por defecto para CUDA.
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
             logger.info("Loading CUDA")
             self.device = "cuda"
@@ -42,7 +32,6 @@ def start(self):
                 torch_dtype=torch.float16,
             ).to(device=self.device)
         elif torch.backends.mps.is_available():
-            # Si no se definió model_path, se asigna el valor por defecto para MPS.
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
             logger.info("Loading MPS for Mac M Series")
             self.device = "mps"
@@ -55,18 +44,13 @@ def start(self):
 
 class TextToImagePipelineFlux:
     def __init__(self, model_path: str | None = None, low_vram: bool = False):
-        """
-        Inicialización de la clase con la ruta del modelo.
-        Si no se proporciona, se obtiene de la variable de entorno.
-        """
         self.model_path = model_path or os.getenv("MODEL_PATH")
         self.pipeline: FluxPipeline = None
         self.device: str = None
         self.low_vram = low_vram
 
     def start(self):
         if torch.cuda.is_available():
-            # Si no se definió model_path, se asigna el valor por defecto para CUDA.
             model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
             logger.info("Loading CUDA")
             self.device = "cuda"
@@ -79,7 +63,6 @@ def start(self):
             else:
                 pass
         elif torch.backends.mps.is_available():
-            # Si no se definió model_path, se asigna el valor por defecto para MPS.
             model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
             logger.info("Loading MPS for Mac M Series")
             self.device = "mps"
@@ -92,17 +75,12 @@ def start(self):
 
 class TextToImagePipelineSD:
     def __init__(self, model_path: str | None = None):
-        """
-        Inicialización de la clase con la ruta del modelo.
-        Si no se proporciona, se obtiene de la variable de entorno.
-        """
         self.model_path = model_path or os.getenv("MODEL_PATH")
         self.pipeline: StableDiffusionPipeline = None
         self.device: str = None
 
     def start(self):
         if torch.cuda.is_available():
-            # Si no se definió model_path, se asigna el valor por defecto para CUDA.
             model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
             logger.info("Loading CUDA")
             self.device = "cuda"
@@ -111,7 +89,6 @@ def start(self):
                 torch_dtype=torch.float16,
             ).to(device=self.device)
         elif torch.backends.mps.is_available():
-            # Si no se definió model_path, se asigna el valor por defecto para MPS.
             model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
             logger.info("Loading MPS for Mac M Series")
             self.device = "mps"
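
Note: the docstring and comment removals above do not change behavior. Each wrapper still resolves its model path from the constructor argument or the MODEL_PATH environment variable, and start() picks CUDA or MPS with a per-device default checkpoint. A minimal usage sketch of the SD3 wrapper follows; the import path and the prompt are illustrative assumptions, not part of this commit.

# Sketch: using TextToImagePipelineSD3 after this commit (import path and prompt are assumed).
import os

from DiffusersServer.Pipelines import TextToImagePipelineSD3  # assumed package layout

os.environ.setdefault("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")

sd3 = TextToImagePipelineSD3()   # falls back to MODEL_PATH when no path is given
sd3.start()                      # chooses "cuda" or "mps" and loads a default checkpoint

# The underlying diffusers pipeline is exposed as .pipeline
result = sd3.pipeline(prompt="a lighthouse at dawn", num_inference_steps=28)
result.images[0].save("lighthouse.png")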

examples/server-async/DiffusersServer/create_server.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# from https://github.com/F4k3r22/DiffusersServer/blob/main/DiffusersServer/create_server.py
+# create_server.py
 
 from .Pipelines import *
 from .serverasync import *

examples/server-async/DiffusersServer/serverasync.py

Lines changed: 117 additions & 62 deletions
@@ -1,6 +1,4 @@
-# from https://github.com/F4k3r22/DiffusersServer/blob/main/DiffusersServer/serverasync.py
-
-from fastapi import FastAPI, HTTPException, status
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import FileResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -22,6 +20,8 @@
 from typing import Optional, Dict, Any, Type
 from dataclasses import dataclass, field
 from typing import List
+from contextlib import asynccontextmanager
+import asyncio
 
 @dataclass
 class PresetModels:
@@ -114,19 +114,108 @@ def save_video(self, video, fps):
 
 @dataclass
 class ServerConfigModels:
-    model: str = 'stabilityai/stable-diffusion-3-medium'
+    model: str = 'stabilityai/stable-diffusion-3-medium'
     type_models: str = 't2im'
     custom_model : bool = False
     constructor_pipeline: Optional[Type] = None
-    custom_pipeline: Optional[Type] = None
+    custom_pipeline: Optional[Type] = None
     components: Optional[Dict[str, Any]] = None
     api_name: Optional[str] = 'custom_api'
     torch_dtype: Optional[torch.dtype] = None
     host: str = '0.0.0.0'
     port: int = 8500
 
 def create_app_fastapi(config: ServerConfigModels) -> FastAPI:
-    app = FastAPI()
+
+    server_config = config or ServerConfigModels()
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        logging.basicConfig(level=logging.INFO)
+        app.state.logger = logging.getLogger("diffusers-server")
+
+        app.state.total_requests = 0
+        app.state.active_inferences = 0
+        app.state.metrics_lock = asyncio.Lock()
+        app.state.metrics_task = None
+
+        app.state.utils_app = Utils(
+            host=server_config.host,
+            port=server_config.port,
+        )
+
+        async def metrics_loop():
+            try:
+                while True:
+                    async with app.state.metrics_lock:
+                        total = app.state.total_requests
+                        active = app.state.active_inferences
+                    app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
+                    await asyncio.sleep(5)
+            except asyncio.CancelledError:
+                app.state.logger.info("Metrics loop cancelled")
+                raise
+
+        app.state.metrics_task = asyncio.create_task(metrics_loop())
+
+        try:
+            yield
+        finally:
+            # 🔻 shutdown
+            task = app.state.metrics_task
+            if task:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+            try:
+                stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
+                if callable(stop_fn):
+                    await run_in_threadpool(stop_fn)
+            except Exception as e:
+                app.state.logger.warning(f"Error during pipeline shutdown: {e}")
+
+            app.state.logger.info("Lifespan shutdown complete")
+
+
+
+    app = FastAPI(lifespan=lifespan)
+
+    logger = logging.getLogger("DiffusersServer.Pipelines")
+
+    if server_config.custom_model:
+        if server_config.constructor_pipeline is None:
+            raise ValueError("constructor_pipeline cannot be None - a valid pipeline constructor is required")
+
+        initializer = server_config.constructor_pipeline(
+            model_path=server_config.model,
+            pipeline=server_config.custom_pipeline,
+            torch_dtype=server_config.torch_dtype,
+            components=server_config.components,
+        )
+        model_pipeline = initializer.start()
+        request_pipe = None
+        pipeline_lock = threading.Lock()
+
+    else:
+        initializer = ModelPipelineInitializer(
+            model=server_config.model,
+            type_models=server_config.type_models,
+        )
+        model_pipeline = initializer.initialize_pipeline()
+        model_pipeline.start()
+
+        request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
+        pipeline_lock = threading.Lock()
+
+    logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
+
+    app.state.MODEL_INITIALIZER = initializer
+    app.state.MODEL_PIPELINE = model_pipeline
+    app.state.REQUEST_PIPE = request_pipe
+    app.state.PIPELINE_LOCK = pipeline_lock
 
     class JSONBodyQueryAPI(BaseModel):
         model : str | None = None
@@ -135,54 +224,12 @@ class JSONBodyQueryAPI(BaseModel):
         num_inference_steps : int = 28
         num_images_per_prompt : int = 1
 
-    logging.basicConfig(level=logging.INFO)
-    global logger
-    logger = logging.getLogger(__name__)
-
-    server_config = config or ServerConfigModels()
-    app.state.SERVER_CONFIG = server_config
-
-    global utils_app
-
-    utils_app = Utils(host=server_config.host, port=server_config.port)
-
-    logger.info(f"Inicializando pipeline para el modelo: {server_config.model}")
-    try:
-        if server_config.custom_model:
-            if server_config.constructor_pipeline is None:
-                raise ValueError("constructor_pipeline cannot be None - a valid pipeline constructor is required")
-            initializer = server_config.constructor_pipeline(
-                model_path=server_config.model,
-                pipeline=server_config.custom_pipeline,
-                torch_dtype=server_config.torch_dtype,
-                components=server_config.components,
-            )
-            model_pipeline = initializer.start()
-            app.state.CUSTOM_PIPELINE = server_config.custom_pipeline
-            app.state.MODEL_PIPELINE = model_pipeline
-            app.state.MODEL_INITIALIZER = initializer
-            logger.info(f"Pipeline personalizado inicializado. Tipo: {type(model_pipeline)}")
-        else:
-            initializer = ModelPipelineInitializer(
-                model=server_config.model,
-                type_models=server_config.type_models,
-            )
-            model_pipeline = initializer.initialize_pipeline()
-            model_pipeline.start()
-
-            app.state.REQUEST_PIPE = RequestScopedPipeline(model_pipeline.pipeline)
-
-            # Lock for concurrency
-            pipeline_lock = threading.Lock()
-
-            app.state.MODEL_PIPELINE = model_pipeline
-            app.state.PIPELINE_LOCK = pipeline_lock
-            app.state.MODEL_INITIALIZER = initializer
-
-            logger.info("Pipeline initialized and ready to receive requests")
-    except Exception as e:
-        logger.error(f"Error initializing pipeline: {e}")
-        raise
+    @app.middleware("http")
+    async def count_requests_middleware(request: Request, call_next):
+        async with app.state.metrics_lock:
+            app.state.total_requests += 1
+        response = await call_next(request)
+        return response
 
 
     @app.get("/")
@@ -196,14 +243,16 @@ async def api(json: JSONBodyQueryAPI):
         num_steps = json.num_inference_steps
         num_images_per_prompt = json.num_images_per_prompt
 
-        wrapper = app.state.MODEL_PIPELINE
+        wrapper = app.state.MODEL_PIPELINE
         initializer = app.state.MODEL_INITIALIZER
 
+        utils_app = app.state.utils_app
+
 
         if not wrapper or not wrapper.pipeline:
-            raise HTTPException(500, "Modelo no inicializado correctamente")
+            raise HTTPException(500, "Model not initialized correctly")
         if not prompt.strip():
-            raise HTTPException(400, "No se proporcionó prompt")
+            raise HTTPException(400, "No prompt provided")
 
         def make_generator():
             g = torch.Generator(device=initializer.device)
@@ -212,9 +261,6 @@ def make_generator():
         req_pipe = app.state.REQUEST_PIPE
 
        def infer():
-            # This is called that because the RequestScoped Pipeline already internally
-            # handles everything necessary for inference and only the
-            # model pipeline needs to be passed, for example StableDiffusion3Pipeline
            gen = make_generator()
            return req_pipe.generate(
                prompt=prompt,
@@ -226,14 +272,22 @@ def infer():
            )
 
        try:
+            async with app.state.metrics_lock:
+                app.state.active_inferences += 1
+
            output = await run_in_threadpool(infer)
 
+            async with app.state.metrics_lock:
+                app.state.active_inferences = max(0, app.state.active_inferences - 1)
+
            urls = [utils_app.save_image(img) for img in output.images]
            return {"response": urls}
 
        except Exception as e:
-            logger.error(f"Error durante la inferencia: {e}")
-            raise HTTPException(500, f"Error en procesamiento: {e}")
+            async with app.state.metrics_lock:
+                app.state.active_inferences = max(0, app.state.active_inferences - 1)
+            logger.error(f"Error during inference: {e}")
+            raise HTTPException(500, f"Error in processing: {e}")
 
        finally:
            import gc; gc.collect()
@@ -243,6 +297,7 @@ def infer():
 
     @app.get("/images/{filename}")
     async def serve_image(filename: str):
+        utils_app = app.state.utils_app
        file_path = os.path.join(utils_app.image_dir, filename)
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="Image not found")
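
Note: with the changes above, create_app_fastapi() builds the whole server from a ServerConfigModels instance: the lifespan handler owns the metrics loop and the Utils instance, the pipeline (or its RequestScopedPipeline wrapper) is stored on app.state, and an HTTP middleware counts incoming requests. A rough launch sketch follows; running under uvicorn and the module path are assumptions, not shown in this diff.

# Sketch: building and serving the app (uvicorn and the import path are assumed).
import uvicorn

from DiffusersServer.serverasync import ServerConfigModels, create_app_fastapi

config = ServerConfigModels(
    model="stabilityai/stable-diffusion-3-medium",  # default model in ServerConfigModels
    type_models="t2im",
    host="0.0.0.0",
    port=8500,
)

app = create_app_fastapi(config)  # sets up lifespan, metrics, and pipeline state

if __name__ == "__main__":
    uvicorn.run(app, host=config.host, port=config.port)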

examples/server-async/DiffusersServer/superpipeline.py

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-# from https://github.com/F4k3r22/DiffusersServer/blob/main/DiffusersServer/superpipeline.py
-
 from diffusers.pipelines import *
 from diffusers import *
 import torch
