@@ -1,6 +1,4 @@
-# from https://github.com/F4k3r22/DiffusersServer/blob/main/DiffusersServer/serverasync.py
-
-from fastapi import FastAPI, HTTPException, status
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import FileResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -22,6 +20,8 @@
 from typing import Optional, Dict, Any, Type
 from dataclasses import dataclass, field
 from typing import List
+from contextlib import asynccontextmanager
+import asyncio
 
 @dataclass
 class PresetModels:
@@ -114,19 +114,108 @@ def save_video(self, video, fps):
 
 @dataclass
 class ServerConfigModels:
-    model: str = 'stabilityai/stable-diffusion-3-medium'
+    model: str = 'stabilityai/stable-diffusion-3-medium'
     type_models: str = 't2im'
     custom_model: bool = False
     constructor_pipeline: Optional[Type] = None
-    custom_pipeline: Optional[Type] = None
+    custom_pipeline: Optional[Type] = None
     components: Optional[Dict[str, Any]] = None
     api_name: Optional[str] = 'custom_api'
     torch_dtype: Optional[torch.dtype] = None
     host: str = '0.0.0.0'
     port: int = 8500
 
 def create_app_fastapi(config: ServerConfigModels) -> FastAPI:
-    app = FastAPI()
+
+    server_config = config or ServerConfigModels()
+
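+    # Lifespan: initialize logging, metrics counters, and the Utils helper on startup;
+    # cancel the metrics task and stop the pipeline on shutdown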
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        logging.basicConfig(level=logging.INFO)
+        app.state.logger = logging.getLogger("diffusers-server")
+
+        app.state.total_requests = 0
+        app.state.active_inferences = 0
+        app.state.metrics_lock = asyncio.Lock()
+        app.state.metrics_task = None
+
+        app.state.utils_app = Utils(
+            host=server_config.host,
+            port=server_config.port,
+        )
+
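+        # Background task: log total and active request counts every 5 seconds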
+        async def metrics_loop():
+            try:
+                while True:
+                    async with app.state.metrics_lock:
+                        total = app.state.total_requests
+                        active = app.state.active_inferences
+                    app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
+                    await asyncio.sleep(5)
+            except asyncio.CancelledError:
+                app.state.logger.info("Metrics loop cancelled")
+                raise
+
+        app.state.metrics_task = asyncio.create_task(metrics_loop())
+
+        try:
+            yield
+        finally:
+            # 🔻 shutdown
+            task = app.state.metrics_task
+            if task:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
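+            # Best-effort pipeline cleanup: call stop()/close() on the pipeline wrapper if available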
+            try:
+                stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
+                if callable(stop_fn):
+                    await run_in_threadpool(stop_fn)
+            except Exception as e:
+                app.state.logger.warning(f"Error during pipeline shutdown: {e}")
+
+            app.state.logger.info("Lifespan shutdown complete")
+
+
+
+    app = FastAPI(lifespan=lifespan)
+
+    logger = logging.getLogger("DiffusersServer.Pipelines")
+
+    if server_config.custom_model:
+        if server_config.constructor_pipeline is None:
+            raise ValueError("constructor_pipeline cannot be None - a valid pipeline constructor is required")
+
+        initializer = server_config.constructor_pipeline(
+            model_path=server_config.model,
+            pipeline=server_config.custom_pipeline,
+            torch_dtype=server_config.torch_dtype,
+            components=server_config.components,
+        )
+        model_pipeline = initializer.start()
+        request_pipe = None
+        pipeline_lock = threading.Lock()
+
+    else:
+        initializer = ModelPipelineInitializer(
+            model=server_config.model,
+            type_models=server_config.type_models,
+        )
+        model_pipeline = initializer.initialize_pipeline()
+        model_pipeline.start()
+
+        request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
+        pipeline_lock = threading.Lock()
+
+    logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")
+
+    app.state.MODEL_INITIALIZER = initializer
+    app.state.MODEL_PIPELINE = model_pipeline
+    app.state.REQUEST_PIPE = request_pipe
+    app.state.PIPELINE_LOCK = pipeline_lock
 
     class JSONBodyQueryAPI(BaseModel):
         model: str | None = None
@@ -135,54 +224,12 @@ class JSONBodyQueryAPI(BaseModel):
         num_inference_steps: int = 28
         num_images_per_prompt: int = 1
 
-    logging.basicConfig(level=logging.INFO)
-    global logger
-    logger = logging.getLogger(__name__)
-
-    server_config = config or ServerConfigModels()
-    app.state.SERVER_CONFIG = server_config
-
-    global utils_app
-
-    utils_app = Utils(host=server_config.host, port=server_config.port)
-
-    logger.info(f"Inicializando pipeline para el modelo: {server_config.model}")
-    try:
-        if server_config.custom_model:
-            if server_config.constructor_pipeline is None:
-                raise ValueError("constructor_pipeline cannot be None - a valid pipeline constructor is required")
-            initializer = server_config.constructor_pipeline(
-                model_path=server_config.model,
-                pipeline=server_config.custom_pipeline,
-                torch_dtype=server_config.torch_dtype,
-                components=server_config.components,
-            )
-            model_pipeline = initializer.start()
-            app.state.CUSTOM_PIPELINE = server_config.custom_pipeline
-            app.state.MODEL_PIPELINE = model_pipeline
-            app.state.MODEL_INITIALIZER = initializer
-            logger.info(f"Pipeline personalizado inicializado. Tipo: {type(model_pipeline)}")
-        else:
-            initializer = ModelPipelineInitializer(
-                model=server_config.model,
-                type_models=server_config.type_models,
-            )
-            model_pipeline = initializer.initialize_pipeline()
-            model_pipeline.start()
-
-            app.state.REQUEST_PIPE = RequestScopedPipeline(model_pipeline.pipeline)
-
-            # Lock for concurrency
-            pipeline_lock = threading.Lock()
-
-            app.state.MODEL_PIPELINE = model_pipeline
-            app.state.PIPELINE_LOCK = pipeline_lock
-            app.state.MODEL_INITIALIZER = initializer
-
-            logger.info("Pipeline initialized and ready to receive requests")
-    except Exception as e:
-        logger.error(f"Error initializing pipeline: {e}")
-        raise
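+    # Count every incoming HTTP request so the metrics loop can report totals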
+    @app.middleware("http")
+    async def count_requests_middleware(request: Request, call_next):
+        async with app.state.metrics_lock:
+            app.state.total_requests += 1
+        response = await call_next(request)
+        return response
 
 
     @app.get("/")
@@ -196,14 +243,16 @@ async def api(json: JSONBodyQueryAPI):
         num_steps = json.num_inference_steps
         num_images_per_prompt = json.num_images_per_prompt
 
-        wrapper = app.state.MODEL_PIPELINE
+        wrapper = app.state.MODEL_PIPELINE
         initializer = app.state.MODEL_INITIALIZER
 
+        utils_app = app.state.utils_app
+
 
         if not wrapper or not wrapper.pipeline:
-            raise HTTPException(500, "Modelo no inicializado correctamente")
+            raise HTTPException(500, "Model not initialized correctly")
         if not prompt.strip():
-            raise HTTPException(400, "No se proporcionó prompt")
+            raise HTTPException(400, "No prompt provided")
 
         def make_generator():
             g = torch.Generator(device=initializer.device)
@@ -212,9 +261,6 @@ def make_generator():
         req_pipe = app.state.REQUEST_PIPE
 
         def infer():
-            # This is called that because the RequestScoped Pipeline already internally
-            # handles everything necessary for inference and only the
-            # model pipeline needs to be passed, for example StableDiffusion3Pipeline
             gen = make_generator()
             return req_pipe.generate(
                 prompt=prompt,
@@ -226,14 +272,22 @@ def infer():
             )
 
         try:
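+            # Mark this inference as in flight for the metrics loop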
+            async with app.state.metrics_lock:
+                app.state.active_inferences += 1
+
             output = await run_in_threadpool(infer)
 
+            async with app.state.metrics_lock:
+                app.state.active_inferences = max(0, app.state.active_inferences - 1)
+
             urls = [utils_app.save_image(img) for img in output.images]
             return {"response": urls}
 
         except Exception as e:
-            logger.error(f"Error durante la inferencia: {e}")
-            raise HTTPException(500, f"Error en procesamiento: {e}")
+            async with app.state.metrics_lock:
+                app.state.active_inferences = max(0, app.state.active_inferences - 1)
+            logger.error(f"Error during inference: {e}")
+            raise HTTPException(500, f"Error in processing: {e}")
 
         finally:
             import gc; gc.collect()
@@ -243,6 +297,7 @@ def infer():
 
     @app.get("/images/{filename}")
     async def serve_image(filename: str):
+        utils_app = app.state.utils_app
         file_path = os.path.join(utils_app.image_dir, filename)
         if not os.path.isfile(file_path):
             raise HTTPException(status_code=404, detail="Image not found")