35 changes: 35 additions & 0 deletions examples/server/README.md
@@ -0,0 +1,35 @@

## OpenAI Compatible `/v1/images/generations` Server

This example is a concurrent, multithreaded server for generating images with the `diffusers` library. It uses the Stable Diffusion 3 pipeline, but you can swap in any pipeline you like by changing the model and pipeline class.
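
As a minimal sketch of what that swap looks like, here is `FluxPipeline` used as a stand-in (the model id is purely illustrative; in this example the pipeline actually lives in `server.py`):
```py
# Illustrative only: any diffusers text-to-image pipeline can replace
# StableDiffusion3Pipeline in server.py in the same way.
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example model id
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline("a kitten in front of a fireplace").images[0]
```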

### Installing Dependencies

Start by going to the base of the repo and installing it with:
```sh
pip install .
```

The example's remaining dependencies can then be installed with:
```sh
pip install -r requirements.txt
```

### Running the Server

The server can be run with:
```sh
python server.py
```
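
The model is chosen via the `MODEL_PATH` environment variable read in `server.py`; it defaults to `stabilityai/stable-diffusion-3.5-large` on CUDA and `stabilityai/stable-diffusion-3.5-medium` on Apple Silicon (MPS). For example, to pin the medium model:
```sh
MODEL_PATH=stabilityai/stable-diffusion-3.5-medium python server.py
```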
The server will be spun up at http://localhost:8000. You can then send it a request with the following `curl` command:
```sh
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
```
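
The response contains a URL to the generated image, which the server writes to a temporary directory and serves from its `/images` mount. An illustrative response (the filename is random) looks like:
```json
{"data": [{"url": "http://localhost:8000/images/draw4a8b2c1d.png"}]}
```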

> **Member:** Can we also add some notes on the multi-threaded nature of this example and show some GPU utilization numbers with and without threading?
>
> **Contributor Author:** Yes. I'll add that in the next commit.

### Upgrading Dependencies

If you need to upgrade some dependencies, you can do that with either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). With `uv`, this looks like:
```sh
uv pip compile requirements.in -o requirements.txt
```
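
To bump a single pin rather than the whole lockfile, `uv pip compile` also accepts an `--upgrade-package` flag (mirroring pip-tools), e.g.:
```sh
uv pip compile requirements.in -o requirements.txt --upgrade-package transformers
```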

9 changes: 9 additions & 0 deletions examples/server/requirements.in
@@ -0,0 +1,9 @@
torch~=2.4.0
transformers==4.46.1
sentencepiece
aiohttp
py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
uvicorn
124 changes: 124 additions & 0 deletions examples/server/requirements.txt
@@ -0,0 +1,124 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
    # via aiohttp
aiohttp==3.10.10
    # via -r requirements.in
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.7.0
    # via pydantic
anyio==4.6.2.post1
    # via starlette
attrs==24.2.0
    # via aiohttp
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
click==8.1.7
    # via uvicorn
fastapi==0.115.3
    # via -r requirements.in
filelock==3.16.1
    # via
    #   huggingface-hub
    #   torch
    #   transformers
frozenlist==1.5.0
    # via
    #   aiohttp
    #   aiosignal
fsspec==2024.10.0
    # via
    #   huggingface-hub
    #   torch
h11==0.14.0
    # via uvicorn
huggingface-hub==0.26.1
    # via
    #   tokenizers
    #   transformers
idna==3.10
    # via
    #   anyio
    #   requests
    #   yarl
jinja2==3.1.4
    # via torch
markupsafe==3.0.2
    # via jinja2
mpmath==1.3.0
    # via sympy
multidict==6.1.0
    # via
    #   aiohttp
    #   yarl
networkx==3.4.2
    # via torch
numpy==2.1.2
    # via transformers
packaging==24.1
    # via
    #   huggingface-hub
    #   transformers
prometheus-client==0.21.0
    # via
    #   -r requirements.in
    #   prometheus-fastapi-instrumentator
prometheus-fastapi-instrumentator==7.0.0
    # via -r requirements.in
propcache==0.2.0
    # via yarl
py-consul==1.5.3
    # via -r requirements.in
pydantic==2.9.2
    # via fastapi
pydantic-core==2.23.4
    # via pydantic
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
regex==2024.9.11
    # via transformers
requests==2.32.3
    # via
    #   huggingface-hub
    #   py-consul
    #   transformers
safetensors==0.4.5
    # via transformers
sentencepiece==0.2.0
    # via -r requirements.in
sniffio==1.3.1
    # via anyio
starlette==0.41.0
    # via
    #   fastapi
    #   prometheus-fastapi-instrumentator
sympy==1.13.3
    # via torch
tokenizers==0.20.1
    # via transformers
torch==2.4.1
    # via -r requirements.in
tqdm==4.66.5
    # via
    #   huggingface-hub
    #   transformers
transformers==4.46.1
    # via -r requirements.in
typing-extensions==4.12.2
    # via
    #   fastapi
    #   huggingface-hub
    #   pydantic
    #   pydantic-core
    #   torch
urllib3==2.2.3
    # via requests
uvicorn==0.32.0
    # via -r requirements.in
yarl==1.16.0
    # via aiohttp
113 changes: 113 additions & 0 deletions examples/server/server.py
@@ -0,0 +1,113 @@
import asyncio
import logging
import os
import random
import tempfile
import traceback
import uuid

import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline

logger = logging.getLogger(__name__)

class TextToImageInput(BaseModel):
    model: str
    prompt: str
    size: str | None = None
    n: int | None = None


class HttpClient:
    session: aiohttp.ClientSession = None

    def start(self):
        self.session = aiohttp.ClientSession()

    async def stop(self):
        await self.session.close()
        self.session = None

    def __call__(self) -> aiohttp.ClientSession:
        assert self.session is not None
        return self.session

class TextToImagePipeline:
    pipeline: StableDiffusion3Pipeline = None
    device: str = None

    def start(self):
        if torch.cuda.is_available():
            model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")
            logger.info("Loading CUDA")
            self.device = "cuda"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        elif torch.backends.mps.is_available():
            model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-medium")
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        else:
            raise Exception("No CUDA or MPS device available")

app = FastAPI()
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")
image_dir = os.path.join(tempfile.gettempdir(), "images")
if not os.path.exists(image_dir):
    os.makedirs(image_dir)
app.mount("/images", StaticFiles(directory=image_dir), name="images")
http_client = HttpClient()
shared_pipeline = TextToImagePipeline()

# Configure CORS settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.
    allow_headers=["*"],  # Allows all headers
)

@app.on_event("startup")
def startup():
    http_client.start()
    shared_pipeline.start()


def save_image(image):
    filename = "draw" + str(uuid.uuid4()).split("-")[0] + ".png"
    image_path = os.path.join(image_dir, filename)
    # write image to disk at image_path
    logger.info(f"Saving image to {image_path}")
    image.save(image_path)
    return os.path.join(service_url, "images", filename)

@app.get("/")
@app.post("/")
@app.options("/")
async def base():
    return "Welcome to Diffusers, where you can use diffusion models to generate images!"

@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
try:
loop = asyncio.get_event_loop()
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
generator =torch.Generator(device=shared_pipeline.device)
generator.manual_seed(random.randint(0, 10000000))
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
logger.info(f"output: {output}")
image_url = save_image(output.images[0])
return {"data": [{"url": image_url}]}
except Exception as e:
if isinstance(e, HTTPException):
raise e
elif hasattr(e, 'message'):
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
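
For reference, a minimal Python client for the `/v1/images/generations` endpoint above; this is a sketch using the `requests` library (a transitive dependency in `requirements.txt`), with the same example payload as the `curl` command in the README:
```py
# Minimal client sketch for the image-generation endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/v1/images/generations",
    json={"model": "something", "prompt": "a kitten in front of a fireplace"},
    timeout=300,  # generation can take minutes on slower hardware
)
resp.raise_for_status()
print(resp.json()["data"][0]["url"])  # URL of the generated image
```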