Skip to content

Commit 8b8d3aa

Browse files
authored
Merge pull request #43 from huggingface/diffusers-auto
Add Stable Diffusion XL and Diffusers autopipeline
2 parents 00cdb53 + 9d53cc2 commit 8b8d3aa

File tree

10 files changed

+57
-69
lines changed

10 files changed

+57
-69
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ docker build -t starlette-transformers:gpu -f dockerfiles/tensorflow/gpu/Dockerf
4040
```bash
4141
docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering starlette-transformers:cpu
4242
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text starlette-transformers:gpu
43+
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=echarlaix/tiny-random-stable-diffusion-xl -e HF_TASK=text-to-image starlette-transformers:gpu
44+
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=stabilityai/stable-diffusion-xl-base-1.0 -e HF_TASK=text-to-image starlette-transformers:gpu
4345
docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository starlette-transformers:cpu
4446
```
4547

dockerfiles/pytorch/cpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ RUN apt-get update \
2323
# install micromamba
2424
ENV MAMBA_ROOT_PREFIX=/opt/conda
2525
ENV PATH=/opt/conda/bin:$PATH
26-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
26+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
2727
&& touch /root/.bashrc \
2828
&& ./bin/micromamba shell init -s bash -p /opt/conda \
2929
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

dockerfiles/pytorch/cpu/environment.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ dependencies:
55
- python=3.9.13
66
- pytorch::pytorch=1.13.1=py3.9_cpu_0
77
- pip:
8-
- transformers[sklearn,sentencepiece,audio,vision]==4.27.2
8+
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
99
- sentence_transformers==2.2.2
1010
- torchvision==0.14.1
11-
- diffusers==0.14.0
12-
- accelerate==0.17.1
11+
- diffusers==0.19.3
12+
- accelerate==0.21.0
13+
- safetensors

dockerfiles/pytorch/gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/conda
2525
ENV PATH=/opt/conda/bin:$PATH
2626
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"
2727

28-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
28+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
2929
&& touch /root/.bashrc \
3030
&& ./bin/micromamba shell init -s bash -p /opt/conda \
3131
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

dockerfiles/pytorch/gpu/environment.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ dependencies:
99
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
1010
- sentence_transformers==2.2.2
1111
- torchvision==0.14.1
12-
- diffusers==0.18.2
13-
- accelerate==0.21.0
12+
- diffusers==0.19.3
13+
- accelerate==0.21.0
14+
- safetensors

dockerfiles/tensorflow/cpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ RUN apt-get update \
2323
# install micromamba
2424
ENV MAMBA_ROOT_PREFIX=/opt/conda
2525
ENV PATH=/opt/conda/bin:$PATH
26-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
26+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
2727
&& touch /root/.bashrc \
2828
&& ./bin/micromamba shell init -s bash -p /opt/conda \
2929
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

dockerfiles/tensorflow/gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/conda
2626
ENV PATH=/opt/conda/bin:$PATH
2727
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"
2828

29-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
29+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
3030
&& touch /root/.bashrc \
3131
&& ./bin/micromamba shell init -s bash -p /opt/conda \
3232
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import importlib.util
2-
import json
3-
import os
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
46

57
_diffusers = importlib.util.find_spec("diffusers") is not None
68

@@ -11,60 +13,46 @@ def is_diffusers_available():
1113

1214
if is_diffusers_available():
1315
import torch
14-
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
15-
16-
17-
def check_supported_pipeline(model_dir):
18-
try:
19-
with open(os.path.join(model_dir, "model_index.json")) as json_file:
20-
data = json.load(json_file)
21-
if data["_class_name"] == "StableDiffusionPipeline":
22-
return True
23-
else:
24-
return False
25-
except Exception:
26-
return False
16+
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline
2717

2818

29-
class DiffusersPipelineImageToText:
19+
class IEAutoPipelineForText2Image:
3020
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
31-
self.pipeline = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float16)
21+
dtype = torch.float32
22+
if device == "cuda":
23+
dtype = torch.float16
24+
device_map = "auto" if device == "cuda" else None
25+
26+
self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
3227
# try to use DPMSolverMultistepScheduler
33-
try:
34-
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
35-
except Exception:
36-
pass
28+
if isinstance(self.pipeline, StableDiffusionPipeline):
29+
try:
30+
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
31+
except Exception:
32+
pass
3733
self.pipeline.to(device)
3834

3935
def __call__(
4036
self,
4137
prompt,
42-
num_inference_steps=25,
43-
guidance_scale=7.5,
44-
num_images_per_prompt=1,
45-
height=None,
46-
width=None,
47-
negative_prompt=None,
38+
**kwargs,
4839
):
4940
# TODO: add support for more images (Reason is correct output)
50-
num_images_per_prompt = 1
41+
if "num_images_per_prompt" in kwargs:
42+
kwargs.pop("num_images_per_prompt")
43+
logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
5144

5245
# Call pipeline with parameters
53-
out = self.pipeline(
54-
prompt,
55-
num_inference_steps=num_inference_steps,
56-
guidance_scale=guidance_scale,
57-
num_images_per_prompt=num_images_per_prompt,
58-
negative_prompt=negative_prompt,
59-
height=height,
60-
width=width,
61-
)
62-
46+
if self.pipeline.device.type == "cuda":
47+
with torch.autocast("cuda"):
48+
out = self.pipeline(prompt, num_images_per_prompt=1)
49+
else:
50+
out = self.pipeline(prompt, num_images_per_prompt=1)
6351
return out.images[0]
6452

6553

6654
DIFFUSERS_TASKS = {
67-
"text-to-image": DiffusersPipelineImageToText,
55+
"text-to-image": IEAutoPipelineForText2Image,
6856
}
6957

7058

src/huggingface_inference_toolkit/utils.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
from pathlib import Path
55
from typing import Optional, Union
66

7-
from huggingface_hub import login, snapshot_download
7+
from huggingface_hub import HfApi, login, snapshot_download
88
from transformers import WhisperForConditionalGeneration, pipeline
99
from transformers.file_utils import is_tf_available, is_torch_available
1010
from transformers.pipelines import Conversation, Pipeline
1111

1212
from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
1313
from huggingface_inference_toolkit.diffusers_utils import (
14-
check_supported_pipeline,
1514
get_diffusers_pipeline,
1615
is_diffusers_available,
1716
)
@@ -46,11 +45,12 @@ def is_optimum_available():
4645
"pt": "pytorch*",
4746
"flax": "flax*",
4847
"rust": "rust*",
49-
"onnx": "*onnx",
48+
"onnx": "*onnx*",
5049
"safetensors": "*safetensors",
5150
"coreml": "*mlmodel",
5251
"tflite": "*tflite",
5352
"savedmodel": "*tar.gz",
53+
"openvino": "*openvino*",
5454
"ckpt": "*ckpt",
5555
}
5656

@@ -59,18 +59,8 @@ def create_artifact_filter(framework):
5959
"""
6060
Returns a list of regex pattern based on the DL Framework. which will be to used to ignore files when downloading
6161
"""
62-
ignore_regex_list = [
63-
"pytorch*",
64-
"tf*",
65-
"flax*",
66-
"rust*",
67-
"*onnx",
68-
"*safetensors",
69-
"*mlmodel",
70-
"*tflite",
71-
"*tar.gz",
72-
"*ckpt",
73-
]
62+
ignore_regex_list = list(set(framework2weight.values()))
63+
7464
pattern = framework2weight.get(framework, None)
7565
if pattern in ignore_regex_list:
7666
ignore_regex_list.remove(pattern)
@@ -157,6 +147,12 @@ def _load_repository_from_hf(
157147
if not target_dir.exists():
158148
target_dir.mkdir(parents=True)
159149

150+
# check if safetensors weights are available
151+
if framework == "pytorch":
152+
files = HfApi().model_info(repository_id).siblings
153+
if any(f.rfilename.endswith("safetensors") for f in files):
154+
framework = "safetensors"
155+
160156
# create regex to only include the framework specific weights
161157
ignore_regex = create_artifact_filter(framework)
162158
logger.info(f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }")
@@ -259,7 +255,7 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
259255
"sentence-ranking",
260256
]:
261257
hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
262-
elif is_diffusers_available() and check_supported_pipeline(model_dir) and task == "text-to-image":
258+
elif is_diffusers_available() and task == "text-to-image":
263259
hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
264260
else:
265261
hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
@@ -287,8 +283,8 @@ def convert_params_to_int_or_bool(params):
287283
for k, v in params.items():
288284
if v.isnumeric():
289285
params[k] = int(v)
290-
if v == 'false':
286+
if v == "false":
291287
params[k] = False
292-
if v == 'true':
288+
if v == "true":
293289
params[k] = True
294290
return params

tests/unit/test_diffusers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from PIL import Image
44
from transformers.testing_utils import require_torch, slow
55

6-
from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler
7-
from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, DiffusersPipelineImageToText
6+
7+
from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, IEAutoPipelineForText2Image
88
from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline
99

1010

@@ -15,7 +15,7 @@ def test_get_diffusers_pipeline():
1515
"hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname, framework="pytorch"
1616
)
1717
pipe = get_pipeline("text-to-image", storage_dir.as_posix())
18-
assert isinstance(pipe, DiffusersPipelineImageToText)
18+
assert isinstance(pipe, IEAutoPipelineForText2Image)
1919

2020

2121
@slow

0 commit comments

Comments
 (0)