Skip to content

Commit 00cdb53

Browse files
authored
Merge pull request #42 from huggingface/image-to-text
adding image-to-text pipeline
2 parents 2a7d1c4 + 4b2d83b commit 00cdb53

File tree

6 files changed

+27
-12
lines changed

6 files changed

+27
-12
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ docker build -t starlette-transformers:gpu -f dockerfiles/tensorflow/gpu/Dockerf
3939

4040
```bash
4141
docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering starlette-transformers:cpu
42+
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text starlette-transformers:gpu
4243
docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository starlette-transformers:cpu
4344
```
4445

dockerfiles/pytorch/gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM nvidia/cuda:11.7.0-devel-ubuntu22.04
1+
FROM nvidia/cuda:11.7.1-devel-ubuntu22.04
22

33
LABEL maintainer="Hugging Face"
44

dockerfiles/pytorch/gpu/environment.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ dependencies:
66
- nvidia::cudatoolkit=11.7
77
- pytorch::pytorch=1.13.1=py3.9_cuda11.7*
88
- pip:
9-
- transformers[sklearn,sentencepiece,audio,vision]==4.27.2
9+
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
1010
- sentence_transformers==2.2.2
1111
- torchvision==0.14.1
12-
- diffusers==0.14.0
13-
- accelerate==0.17.1
12+
- diffusers==0.18.2
13+
- accelerate==0.21.0

src/huggingface_inference_toolkit/handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def __call__(self, data):
2525
"""
2626
inputs = data.pop("inputs", data)
2727
parameters = data.pop("parameters", None)
28-
2928
# pass inputs with all kwargs in data
3029
if parameters is not None:
3130
prediction = self.pipeline(inputs, **parameters)

src/huggingface_inference_toolkit/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
243243
"zero-shot-image-classification",
244244
}:
245245
kwargs["feature_extractor"] = model_dir
246+
elif task in {"image-to-text"}:
247+
pass
246248
else:
247249
kwargs["tokenizer"] = model_dir
248250

@@ -278,3 +280,15 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
278280
(rank + 1, token) for rank, token in enumerate(hf_pipeline.tokenizer.prefix_tokens[1:])
279281
]
280282
return hf_pipeline
283+
284+
285+
def convert_params_to_int_or_bool(params):
    """Convert string query-parameter values to ``int`` or ``bool`` in place.

    Query parameters arrive as strings (e.g. from ``request.query_params``);
    this coerces the common cases so they can be forwarded as pipeline kwargs.

    Args:
        params (dict[str, str]): mapping of parameter name to string value.
            Mutated in place.

    Returns:
        dict: the same ``params`` object, with decimal strings converted to
        ``int``, ``'true'``/``'false'`` converted to ``True``/``False``, and
        all other values left untouched.
    """
    for key, value in params.items():
        # str.isdecimal() is deliberately used instead of str.isnumeric():
        # isnumeric() also accepts characters such as "²" or "½" that int()
        # cannot parse, which would make int(value) raise ValueError.
        # isdecimal() only matches strings int() can actually convert.
        if value.isdecimal():
            params[key] = int(value)
        # the three cases are mutually exclusive, so an elif chain avoids
        # re-testing values that were already converted
        elif value == 'false':
            params[key] = False
        elif value == 'true':
            params[key] = True
    return params

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler
2020
from huggingface_inference_toolkit.serialization.base import ContentType
2121
from huggingface_inference_toolkit.serialization.json_utils import Jsoner
22-
from huggingface_inference_toolkit.utils import _load_repository_from_hf
22+
from huggingface_inference_toolkit.utils import _load_repository_from_hf, convert_params_to_int_or_bool
2323

2424

2525
def config_logging(level=logging.INFO):
@@ -64,8 +64,6 @@ async def health(request):
6464

6565
async def predict(request):
6666
try:
67-
# tracks request time
68-
start_time = perf_counter()
6967
# extracts content from request
7068
content_type = request.headers.get("content-Type", None)
7169
# try to deserialize payload
@@ -74,13 +72,16 @@ async def predict(request):
7472
if "inputs" not in deserialized_body:
7573
raise ValueError(f"Body needs to provide an inputs key, received: {orjson.dumps(deserialized_body)}")
7674

75+
# check for query parameter and add them to the body
76+
if request.query_params and "parameters" not in deserialized_body:
77+
deserialized_body["parameters"] = convert_params_to_int_or_bool(dict(request.query_params))
78+
print(deserialized_body)
79+
80+
# tracks request time
81+
start_time = perf_counter()
7782
# run async not blocking call
7883
pred = await async_handler_call(inference_handler, deserialized_body)
79-
# run sync blocking call -> slighty faster for < 200ms prediction time
80-
# pred = inference_handler(deserialized_body)
81-
8284
# log request time
83-
# TODO: repalce with middleware
8485
logger.info(f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms")
8586

8687
# response extracts content from request

0 commit comments

Comments
 (0)