Commit 7c4f883: Apply style fixes
Parent: 5598557
7 files changed (+122, -87 lines)

examples/server-async/Pipelines.py

Lines changed: 23 additions & 11 deletions
@@ -1,13 +1,17 @@
-from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
-import torch
-import os
 import logging
-from pydantic import BaseModel
-from dataclasses import dataclass, field
+import os
+from dataclasses import dataclass, field
 from typing import List
 
+import torch
+from pydantic import BaseModel
+
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
+
+
 logger = logging.getLogger(__name__)
 
+
 class TextToImageInput(BaseModel):
     model: str
     prompt: str
@@ -17,8 +21,15 @@ class TextToImageInput(BaseModel):
 
 @dataclass
 class PresetModels:
-    SD3: List[str] = field(default_factory=lambda: ['stabilityai/stable-diffusion-3-medium'])
-    SD3_5: List[str] = field(default_factory=lambda: ['stabilityai/stable-diffusion-3.5-large', 'stabilityai/stable-diffusion-3.5-large-turbo', 'stabilityai/stable-diffusion-3.5-medium'])
+    SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
+    SD3_5: List[str] = field(
+        default_factory=lambda: [
+            "stabilityai/stable-diffusion-3.5-large",
+            "stabilityai/stable-diffusion-3.5-large-turbo",
+            "stabilityai/stable-diffusion-3.5-medium",
+        ]
+    )
+
 
 class TextToImagePipelineSD3:
     def __init__(self, model_path: str | None = None):
@@ -46,8 +57,9 @@ def start(self):
         else:
             raise Exception("No CUDA or MPS device available")
 
+
 class ModelPipelineInitializer:
-    def __init__(self, model: str = '', type_models: str = 't2im'):
+    def __init__(self, model: str = "", type_models: str = "t2im"):
         self.model = model
         self.type_models = type_models
         self.pipeline = None
@@ -68,12 +80,12 @@ def initialize_pipeline(self):
             self.model_type = "SD3_5"
 
         # Create appropriate pipeline based on model type and type_models
-        if self.type_models == 't2im':
+        if self.type_models == "t2im":
            if self.model_type in ["SD3", "SD3_5"]:
                self.pipeline = TextToImagePipelineSD3(self.model)
            else:
                raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
-        elif self.type_models == 't2v':
+        elif self.type_models == "t2v":
            raise ValueError(f"Unsupported type_models: {self.type_models}")
 
-        return self.pipeline
\ No newline at end of file
+        return self.pipeline
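For orientation, here is a minimal usage sketch of the initializer this file defines. The constructor arguments and the TextToImagePipelineSD3 return type come from the diff above; treat the explicit start() call as an assumption, since the hunks only show that start() exists and raises when neither CUDA nor MPS is available.

from Pipelines import ModelPipelineInitializer

# Pick one of the preset SD3.5 checkpoints and the text-to-image mode ("t2im").
initializer = ModelPipelineInitializer(
    model="stabilityai/stable-diffusion-3.5-medium",
    type_models="t2im",
)

# For SD3/SD3_5 models, initialize_pipeline() returns a TextToImagePipelineSD3 wrapper.
pipeline = initializer.initialize_pipeline()

# Assumption: the caller starts the pipeline; start() raises if no CUDA/MPS device exists.
pipeline.start()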

examples/server-async/serverasync.py

Lines changed: 44 additions & 41 deletions
@@ -1,41 +1,45 @@
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import FileResponse
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.concurrency import run_in_threadpool
-from pydantic import BaseModel
-from Pipelines import ModelPipelineInitializer
-from utils import Utils, RequestScopedPipeline
+import asyncio
+import gc
 import logging
-import random
-from dataclasses import dataclass
 import os
-import torch
+import random
 import threading
-import gc
-from typing import Optional, Dict, Any, Type
 from contextlib import asynccontextmanager
-import asyncio
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Type
+
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.concurrency import run_in_threadpool
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from Pipelines import ModelPipelineInitializer
+from pydantic import BaseModel
+
+from utils import RequestScopedPipeline, Utils
 
 
 @dataclass
 class ServerConfigModels:
-    model: str = 'stabilityai/stable-diffusion-3.5-medium'
-    type_models: str = 't2im'
+    model: str = "stabilityai/stable-diffusion-3.5-medium"
+    type_models: str = "t2im"
     constructor_pipeline: Optional[Type] = None
-    custom_pipeline: Optional[Type] = None 
+    custom_pipeline: Optional[Type] = None
     components: Optional[Dict[str, Any]] = None
     torch_dtype: Optional[torch.dtype] = None
-    host: str = '0.0.0.0'
+    host: str = "0.0.0.0"
     port: int = 8500
 
+
 server_config = ServerConfigModels()
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     logging.basicConfig(level=logging.INFO)
     app.state.logger = logging.getLogger("diffusers-server")
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'
-    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
+    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
 
     app.state.total_requests = 0
     app.state.active_inferences = 0
@@ -81,12 +85,12 @@ async def metrics_loop():
 
     app.state.logger.info("Lifespan shutdown complete")
 
+
 app = FastAPI(lifespan=lifespan)
 
 logger = logging.getLogger("DiffusersServer.Pipelines")
 
 
-
 initializer = ModelPipelineInitializer(
     model=server_config.model,
     type_models=server_config.type_models,
@@ -104,12 +108,14 @@ async def metrics_loop():
 app.state.REQUEST_PIPE = request_pipe
 app.state.PIPELINE_LOCK = pipeline_lock
 
+
 class JSONBodyQueryAPI(BaseModel):
-    model : str | None = None
-    prompt : str
-    negative_prompt : str | None = None
-    num_inference_steps : int = 28
-    num_images_per_prompt : int = 1
+    model: str | None = None
+    prompt: str
+    negative_prompt: str | None = None
+    num_inference_steps: int = 28
+    num_images_per_prompt: int = 1
+
 
 @app.middleware("http")
 async def count_requests_middleware(request: Request, call_next):
@@ -123,25 +129,24 @@ async def count_requests_middleware(request: Request, call_next):
 async def root():
     return {"message": "Welcome to the Diffusers Server"}
 
+
 @app.post("/api/diffusers/inference")
 async def api(json: JSONBodyQueryAPI):
-    prompt          = json.prompt
-    negative_prompt = json.negative_prompt or ""
-    num_steps       = json.num_inference_steps
+    prompt = json.prompt
+    negative_prompt = json.negative_prompt or ""
+    num_steps = json.num_inference_steps
     num_images_per_prompt = json.num_images_per_prompt
 
-    wrapper     = app.state.MODEL_PIPELINE
+    wrapper = app.state.MODEL_PIPELINE
     initializer = app.state.MODEL_INITIALIZER
 
     utils_app = app.state.utils_app
 
-
     if not wrapper or not wrapper.pipeline:
         raise HTTPException(500, "Model not initialized correctly")
     if not prompt.strip():
         raise HTTPException(400, "No prompt provided")
 
-
     def make_generator():
         g = torch.Generator(device=initializer.device)
         return g.manual_seed(random.randint(0, 10_000_000))
@@ -168,7 +173,7 @@ def infer():
 
     async with app.state.metrics_lock:
         app.state.active_inferences = max(0, app.state.active_inferences - 1)
-    
+
     urls = [utils_app.save_image(img) for img in output.images]
     return {"response": urls}
 
@@ -195,27 +200,25 @@ async def serve_image(filename: str):
         raise HTTPException(status_code=404, detail="Image not found")
     return FileResponse(file_path, media_type="image/png")
 
+
 @app.get("/api/status")
 async def get_status():
     memory_info = {}
     if torch.cuda.is_available():
         memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
-        memory_reserved = torch.cuda.memory_reserved() / 1024**3   # GB
+        memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
         memory_info = {
             "memory_allocated_gb": round(memory_allocated, 2),
             "memory_reserved_gb": round(memory_reserved, 2),
-            "device": torch.cuda.get_device_name(0)
+            "device": torch.cuda.get_device_name(0),
         }
 
-    return {
-        "current_model" : server_config.model,
-        "type_models" : server_config.type_models,
-        "memory" : memory_info}
-
+    return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
+
 
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"], 
+    allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -224,4 +227,4 @@ async def get_status():
 if __name__ == "__main__":
     import uvicorn
 
-    uvicorn.run(app, host=server_config.host, port=server_config.port)
\ No newline at end of file
+    uvicorn.run(app, host=server_config.host, port=server_config.port)
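A client-side sketch of the two endpoints defined above. The request fields mirror JSONBodyQueryAPI (model and negative_prompt are optional), and the response shapes follow the handlers in this file; the prompt text and timeout are placeholders.

import requests

BASE = "http://localhost:8500"  # ServerConfigModels defaults to 0.0.0.0:8500

payload = {
    "prompt": "a lighthouse at dawn",
    "num_inference_steps": 28,
    "num_images_per_prompt": 1,
}
resp = requests.post(f"{BASE}/api/diffusers/inference", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["response"])  # list of URLs for the saved images

status = requests.get(f"{BASE}/api/status").json()
print(status["current_model"], status["type_models"], status["memory"])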

examples/server-async/test.py

Lines changed: 6 additions & 1 deletion
@@ -1,15 +1,18 @@
 import os
 import time
 import urllib.parse
+
 import requests
 
+
 SERVER_URL = "http://localhost:8500/api/diffusers/inference"
 BASE_URL = "http://localhost:8500"
 DOWNLOAD_FOLDER = "generated_images"
 WAIT_BEFORE_DOWNLOAD = 2  # seconds
 
 os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
 
+
 def save_from_url(url: str) -> str:
     """Download the given URL (relative or absolute) and save it locally."""
     if url.startswith("/"):
@@ -24,11 +27,12 @@ def save_from_url(url: str) -> str:
         f.write(resp.content)
     return path
 
+
 def main():
     payload = {
         "prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
         "num_inference_steps": 30,
-        "num_images_per_prompt": 1
+        "num_images_per_prompt": 1,
     }
 
     print("Sending request...")
@@ -56,5 +60,6 @@ def main():
     except Exception as e:
         print(f"Error downloading {u}: {e}")
 
+
 if __name__ == "__main__":
     main()
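Presumably the end-to-end flow is: start the server with python serverasync.py (its __main__ guard launches uvicorn on the configured host and port, 8500 by default), then run python test.py to POST a prompt and download the resulting images into generated_images/.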
examples/server-async/utils/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 from .requestscopedpipeline import RequestScopedPipeline
-from .utils import Utils
\ No newline at end of file
+from .utils import Utils
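This re-export is what lets serverasync.py import both helpers from the package root in a single line:

from utils import RequestScopedPipeline, Utils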
