@@ -1,41 +1,45 @@
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import FileResponse
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.concurrency import run_in_threadpool
-from pydantic import BaseModel
-from Pipelines import ModelPipelineInitializer
-from utils import Utils, RequestScopedPipeline
+import asyncio
+import gc
 import logging
-import random
-from dataclasses import dataclass
 import os
-import torch
+import random
 import threading
-import gc
-from typing import Optional, Dict, Any, Type
 from contextlib import asynccontextmanager
-import asyncio
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Type
+
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.concurrency import run_in_threadpool
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from Pipelines import ModelPipelineInitializer
+from pydantic import BaseModel
+
+from utils import RequestScopedPipeline, Utils


 @dataclass
 class ServerConfigModels:
-    model: str = 'stabilityai/stable-diffusion-3.5-medium'
-    type_models: str = 't2im'
+    model: str = "stabilityai/stable-diffusion-3.5-medium"
+    type_models: str = "t2im"
     constructor_pipeline: Optional[Type] = None
-    custom_pipeline : Optional[Type] = None
+    custom_pipeline: Optional[Type] = None
     components: Optional[Dict[str, Any]] = None
     torch_dtype: Optional[torch.dtype] = None
-    host: str = '0.0.0.0'
+    host: str = "0.0.0.0"
     port: int = 8500

+
 server_config = ServerConfigModels()

+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     logging.basicConfig(level=logging.INFO)
     app.state.logger = logging.getLogger("diffusers-server")
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'
-    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
+    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

     app.state.total_requests = 0
     app.state.active_inferences = 0
@@ -81,12 +85,12 @@ async def metrics_loop():

     app.state.logger.info("Lifespan shutdown complete")

+
 app = FastAPI(lifespan=lifespan)

 logger = logging.getLogger("DiffusersServer.Pipelines")


-
 initializer = ModelPipelineInitializer(
     model=server_config.model,
     type_models=server_config.type_models,
@@ -104,12 +108,14 @@ async def metrics_loop():
 app.state.REQUEST_PIPE = request_pipe
 app.state.PIPELINE_LOCK = pipeline_lock

+
 class JSONBodyQueryAPI(BaseModel):
-    model : str | None = None
-    prompt : str
-    negative_prompt : str | None = None
-    num_inference_steps : int = 28
-    num_images_per_prompt : int = 1
+    model: str | None = None
+    prompt: str
+    negative_prompt: str | None = None
+    num_inference_steps: int = 28
+    num_images_per_prompt: int = 1
+

 @app.middleware("http")
 async def count_requests_middleware(request: Request, call_next):
@@ -123,25 +129,24 @@ async def count_requests_middleware(request: Request, call_next):
 async def root():
     return {"message": "Welcome to the Diffusers Server"}

+
 @app.post("/api/diffusers/inference")
 async def api(json: JSONBodyQueryAPI):
-    prompt                = json.prompt
-    negative_prompt       = json.negative_prompt or ""
-    num_steps             = json.num_inference_steps
+    prompt = json.prompt
+    negative_prompt = json.negative_prompt or ""
+    num_steps = json.num_inference_steps
     num_images_per_prompt = json.num_images_per_prompt

-    wrapper     = app.state.MODEL_PIPELINE
+    wrapper = app.state.MODEL_PIPELINE
     initializer = app.state.MODEL_INITIALIZER

     utils_app = app.state.utils_app

-
     if not wrapper or not wrapper.pipeline:
         raise HTTPException(500, "Model not initialized correctly")
     if not prompt.strip():
         raise HTTPException(400, "No prompt provided")

-
     def make_generator():
         g = torch.Generator(device=initializer.device)
         return g.manual_seed(random.randint(0, 10_000_000))
@@ -168,7 +173,7 @@ def infer():

     async with app.state.metrics_lock:
         app.state.active_inferences = max(0, app.state.active_inferences - 1)
-
+
     urls = [utils_app.save_image(img) for img in output.images]
     return {"response": urls}

@@ -195,27 +200,25 @@ async def serve_image(filename: str):
         raise HTTPException(status_code=404, detail="Image not found")
     return FileResponse(file_path, media_type="image/png")

+
 @app.get("/api/status")
 async def get_status():
     memory_info = {}
     if torch.cuda.is_available():
         memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
-        memory_reserved  = torch.cuda.memory_reserved() / 1024**3  # GB
+        memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
         memory_info = {
             "memory_allocated_gb": round(memory_allocated, 2),
             "memory_reserved_gb": round(memory_reserved, 2),
-            "device": torch.cuda.get_device_name(0)
+            "device": torch.cuda.get_device_name(0),
         }

-    return {
-        "current_model": server_config.model,
-        "type_models": server_config.type_models,
-        "memory": memory_info}
-
+    return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
+

 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -224,4 +227,4 @@ async def get_status():
 if __name__ == "__main__":
     import uvicorn

-    uvicorn.run(app, host=server_config.host, port=server_config.port)
+    uvicorn.run(app, host=server_config.host, port=server_config.port)
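
For reference, a minimal client sketch against the endpoints shown in this diff. It assumes the server is started with the ServerConfigModels defaults above (port 8500, reached here via localhost) and that the third-party requests package is available; the payload fields mirror JSONBodyQueryAPI.

import requests

# Payload mirrors the JSONBodyQueryAPI model defined in the diff above.
payload = {
    "prompt": "a watercolor lighthouse at dusk",
    "negative_prompt": "blurry, low quality",
    "num_inference_steps": 28,
    "num_images_per_prompt": 1,
}

# POST to the inference route; generation can be slow, so allow a generous timeout.
resp = requests.post(
    "http://localhost:8500/api/diffusers/inference",  # host/port are assumptions based on the defaults
    json=payload,
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["response"])  # the endpoint returns {"response": [<saved image URLs>]}

# The status route reports the configured model and CUDA memory usage.
print(requests.get("http://localhost:8500/api/status", timeout=10).json())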