Add server example #9918
@@ -0,0 +1,35 @@
## OpenAI Compatible `/v1/images/generations` Server

This is a concurrent, multithreaded server for generating images with the `diffusers` library. This example uses the Stable Diffusion 3 pipeline, but you can serve any pipeline you like by swapping in the model and pipeline of your choice.
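For example, here is a minimal sketch (not part of this example) of how a different pipeline could be loaded; the SDXL Turbo checkpoint below is only an illustrative choice, and the `StableDiffusion3Pipeline` references in `server.py` would need to be swapped accordingly:

```py
# Hypothetical swap: load SDXL Turbo instead of Stable Diffusion 3 (illustrative only).
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",  # any text-to-image checkpoint you prefer
    torch_dtype=torch.float16,
).to("cuda")
image = pipeline("a kitten in front of a fireplace").images[0]
```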

### Installing Dependencies
Start at the root of the repository and install `diffusers` with:

```sh
pip install .
```
The example's remaining dependencies can then be installed with:

```sh
pip install -r requirements.txt
```
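Optionally, as a quick check that is not part of the example itself, you can verify that the core packages resolved correctly:

```py
# Quick sanity check that the server's main dependencies import cleanly.
import diffusers, fastapi, torch, transformers

print(diffusers.__version__, fastapi.__version__, torch.__version__, transformers.__version__)
```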

### Running the Server
The server can be run with:

```sh
python server.py
```

The server will start at http://localhost:8000. You can `curl` the endpoint with the following command:

```sh
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
```
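Equivalently, here is a small client sketch (assuming the `requests` package is installed; it is not listed in `requirements.txt`) that posts a prompt and prints the URL of the generated image returned by the server:

```py
# Minimal client sketch: POST a prompt and read the OpenAI-style response,
# which looks like {"data": [{"url": "http://localhost:8000/images/<name>.png"}]}.
import requests

response = requests.post(
    "http://localhost:8000/v1/images/generations",
    json={"model": "something", "prompt": "a kitten in front of a fireplace"},
)
response.raise_for_status()
print(response.json()["data"][0]["url"])
```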

### Upgrading Dependencies
If you need to upgrade some dependencies, you can do that with either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). With `uv`, this looks like:

```sh
uv pip compile requirements.in -o requirements.txt
```

requirements.in
@@ -0,0 +1,9 @@
torch~=2.4.0
transformers==4.46.1
sentencepiece
aiohttp
py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
uvicorn

requirements.txt
@@ -0,0 +1,124 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile requirements.in -o requirements.txt
aiohappyeyeballs==2.4.3
    # via aiohttp
aiohttp==3.10.10
    # via -r requirements.in
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.7.0
    # via pydantic
anyio==4.6.2.post1
    # via starlette
attrs==24.2.0
    # via aiohttp
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
click==8.1.7
    # via uvicorn
fastapi==0.115.3
    # via -r requirements.in
filelock==3.16.1
    # via
    #   huggingface-hub
    #   torch
    #   transformers
frozenlist==1.5.0
    # via
    #   aiohttp
    #   aiosignal
fsspec==2024.10.0
    # via
    #   huggingface-hub
    #   torch
h11==0.14.0
    # via uvicorn
huggingface-hub==0.26.1
    # via
    #   tokenizers
    #   transformers
idna==3.10
    # via
    #   anyio
    #   requests
    #   yarl
jinja2==3.1.4
    # via torch
markupsafe==3.0.2
    # via jinja2
mpmath==1.3.0
    # via sympy
multidict==6.1.0
    # via
    #   aiohttp
    #   yarl
networkx==3.4.2
    # via torch
numpy==2.1.2
    # via transformers
packaging==24.1
    # via
    #   huggingface-hub
    #   transformers
prometheus-client==0.21.0
    # via
    #   -r requirements.in
    #   prometheus-fastapi-instrumentator
prometheus-fastapi-instrumentator==7.0.0
    # via -r requirements.in
propcache==0.2.0
    # via yarl
py-consul==1.5.3
    # via -r requirements.in
pydantic==2.9.2
    # via fastapi
pydantic-core==2.23.4
    # via pydantic
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
regex==2024.9.11
    # via transformers
requests==2.32.3
    # via
    #   huggingface-hub
    #   py-consul
    #   transformers
safetensors==0.4.5
    # via transformers
sentencepiece==0.2.0
    # via -r requirements.in
sniffio==1.3.1
    # via anyio
starlette==0.41.0
    # via
    #   fastapi
    #   prometheus-fastapi-instrumentator
sympy==1.13.3
    # via torch
tokenizers==0.20.1
    # via transformers
torch==2.4.1
    # via -r requirements.in
tqdm==4.66.5
    # via
    #   huggingface-hub
    #   transformers
transformers==4.46.1
    # via -r requirements.in
typing-extensions==4.12.2
    # via
    #   fastapi
    #   huggingface-hub
    #   pydantic
    #   pydantic-core
    #   torch
urllib3==2.2.3
    # via requests
uvicorn==0.32.0
    # via -r requirements.in
yarl==1.16.0
    # via aiohttp

server.py
@@ -0,0 +1,113 @@
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import aiohttp, asyncio, logging, os, random, sys, tempfile, torch, traceback, uuid
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline

logger = logging.getLogger(__name__)


class TextToImageInput(BaseModel):
    model: str
    prompt: str
    size: str | None = None
    n: int | None = None


class HttpClient:
    session: aiohttp.ClientSession = None

    def start(self):
        self.session = aiohttp.ClientSession()

    async def stop(self):
        await self.session.close()
        self.session = None

    def __call__(self) -> aiohttp.ClientSession:
        assert self.session is not None
        return self.session


class TextToImagePipeline:
    # Holds the text-to-image pipeline that is shared across requests.
    pipeline: StableDiffusion3Pipeline = None
    device: str = None

    def start(self):
        if torch.cuda.is_available():
            model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")
            logger.info("Loading CUDA")
            self.device = "cuda"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        elif torch.backends.mps.is_available():
            model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-medium")
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            ).to(device=self.device)
        else:
            raise Exception("No CUDA or MPS device available")


app = FastAPI()
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")
image_dir = os.path.join(tempfile.gettempdir(), "images")
if not os.path.exists(image_dir):
    os.makedirs(image_dir)
app.mount("/images", StaticFiles(directory=image_dir), name="images")
http_client = HttpClient()
shared_pipeline = TextToImagePipeline()

# Configure CORS settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.
    allow_headers=["*"],  # Allows all headers
)


@app.on_event("startup")
def startup():
    http_client.start()
    shared_pipeline.start()


def save_image(image):
    filename = "draw" + str(uuid.uuid4()).split('-')[0] + ".png"
    image_path = os.path.join(image_dir, filename)
    # write image to disk at image_path
    logger.info(f"Saving image to {image_path}")
    image.save(image_path)
    return os.path.join(service_url, "images", filename)


@app.get('/')
@app.post('/')
@app.options('/')
async def base():
    return "Welcome to Diffusers! Where you can use diffusion models to generate images"


@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
    try:
        loop = asyncio.get_event_loop()
        # Give each request its own scheduler and a lightweight pipeline view that
        # reuses the already-loaded weights via from_pipe, so concurrent requests don't clash.
        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
        generator = torch.Generator(device=shared_pipeline.device)
        generator.manual_seed(random.randint(0, 10000000))
        # Run the blocking diffusion call in a worker thread so the event loop stays responsive.
        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
        logger.info(f"output: {output}")
        image_url = save_image(output.images[0])
        return {"data": [{"url": image_url}]}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise e
        elif hasattr(e, 'message'):
            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)