Skip to content

Commit b49ddb7

Browse files
committed
Update micromamba download URL to micro.mamba.pm, add SDXL text-to-image run examples, prefer safetensors weights when available, and use autocast for CUDA diffusers pipelines
1 parent fac9e16 commit b49ddb7

File tree

8 files changed

+25
-23
lines changed

8 files changed

+25
-23
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ docker build -t starlette-transformers:gpu -f dockerfiles/tensorflow/gpu/Dockerf
4040
```bash
4141
docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering starlette-transformers:cpu
4242
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text starlette-transformers:gpu
43+
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=echarlaix/tiny-random-stable-diffusion-xl -e HF_TASK=text-to-image starlette-transformers:gpu
44+
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=stabilityai/stable-diffusion-xl-base-1.0 -e HF_TASK=text-to-image starlette-transformers:gpu
4345
docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository starlette-transformers:cpu
4446
```
4547

dockerfiles/pytorch/cpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ RUN apt-get update \
2323
# install micromamba
2424
ENV MAMBA_ROOT_PREFIX=/opt/conda
2525
ENV PATH=/opt/conda/bin:$PATH
26-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
26+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
2727
&& touch /root/.bashrc \
2828
&& ./bin/micromamba shell init -s bash -p /opt/conda \
2929
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

dockerfiles/pytorch/gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/conda
2525
ENV PATH=/opt/conda/bin:$PATH
2626
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"
2727

28-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
28+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
2929
&& touch /root/.bashrc \
3030
&& ./bin/micromamba shell init -s bash -p /opt/conda \
3131
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

dockerfiles/pytorch/gpu/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ dependencies:
66
- nvidia::cudatoolkit=11.7
77
- pytorch::pytorch=1.13.1=py3.9_cuda11.7*
88
- pip:
9-
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
9+
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
1010
- sentence_transformers==2.2.2
1111
- torchvision==0.14.1
1212
- diffusers==0.19.3

dockerfiles/tensorflow/cpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ RUN apt-get update \
2323
# install micromamba
2424
ENV MAMBA_ROOT_PREFIX=/opt/conda
2525
ENV PATH=/opt/conda/bin:$PATH
26-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
26+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
2727
&& touch /root/.bashrc \
2828
&& ./bin/micromamba shell init -s bash -p /opt/conda \
2929
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

dockerfiles/tensorflow/gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/conda
2626
ENV PATH=/opt/conda/bin:$PATH
2727
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"
2828

29-
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
29+
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
3030
&& touch /root/.bashrc \
3131
&& ./bin/micromamba shell init -s bash -p /opt/conda \
3232
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,11 @@ def __call__(
4343
logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
4444

4545
# Call pipeline with parameters
46-
out = self.pipeline(prompt, num_images_per_prompt=1)
47-
46+
if self.pipeline.device.type == "cuda":
47+
with torch.autocast("cuda"):
48+
out = self.pipeline(prompt, num_images_per_prompt=1)
49+
else:
50+
out = self.pipeline(prompt, num_images_per_prompt=1)
4851
return out.images[0]
4952

5053

src/huggingface_inference_toolkit/utils.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
from pathlib import Path
55
from typing import Optional, Union
66

7-
from huggingface_hub import login, snapshot_download
7+
from huggingface_hub import HfApi, login, snapshot_download
88
from transformers import WhisperForConditionalGeneration, pipeline
99
from transformers.file_utils import is_tf_available, is_torch_available
1010
from transformers.pipelines import Conversation, Pipeline
1111

1212
from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
1313
from huggingface_inference_toolkit.diffusers_utils import (
14-
check_supported_pipeline,
1514
get_diffusers_pipeline,
1615
is_diffusers_available,
1716
)
@@ -46,11 +45,12 @@ def is_optimum_available():
4645
"pt": "pytorch*",
4746
"flax": "flax*",
4847
"rust": "rust*",
49-
"onnx": "*onnx",
48+
"onnx": "*onnx*",
5049
"safetensors": "*safetensors",
5150
"coreml": "*mlmodel",
5251
"tflite": "*tflite",
5352
"savedmodel": "*tar.gz",
53+
"openvino": "*openvino*",
5454
"ckpt": "*ckpt",
5555
}
5656

@@ -59,18 +59,8 @@ def create_artifact_filter(framework):
5959
"""
6060
Returns a list of regex patterns based on the DL framework, which will be used to ignore files when downloading
6161
"""
62-
ignore_regex_list = [
63-
"pytorch*",
64-
"tf*",
65-
"flax*",
66-
"rust*",
67-
"*onnx",
68-
"*safetensors",
69-
"*mlmodel",
70-
"*tflite",
71-
"*tar.gz",
72-
"*ckpt",
73-
]
62+
ignore_regex_list = list(framework2weight.values())
63+
7464
pattern = framework2weight.get(framework, None)
7565
if pattern in ignore_regex_list:
7666
ignore_regex_list.remove(pattern)
@@ -157,6 +147,13 @@ def _load_repository_from_hf(
157147
if not target_dir.exists():
158148
target_dir.mkdir(parents=True)
159149

150+
# check if safetensors weights are available
151+
if framework == "pytorch":
152+
files = HfApi().model_info(repository_id).siblings
153+
if any(f.rfilename.endswith("safetensors") for f in files):
154+
framework = "safetensors"
155+
156+
160157
# create regex to only include the framework specific weights
161158
ignore_regex = create_artifact_filter(framework)
162159
logger.info(f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }")
@@ -259,7 +256,7 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
259256
"sentence-ranking",
260257
]:
261258
hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
262-
elif is_diffusers_available() and check_supported_pipeline(model_dir) and task == "text-to-image":
259+
elif is_diffusers_available() and task == "text-to-image":
263260
hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
264261
else:
265262
hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)

0 commit comments

Comments
 (0)