Skip to content

Commit 2a6e662

Browse files
committed
application/octet-stream support in content type deserialization
no reason not to accept it Signed-off-by: Raphael Glon <[email protected]>
1 parent 143fa85 commit 2a6e662

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/huggingface_inference_toolkit/serialization/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from huggingface_inference_toolkit.const import HF_TASK
12
from huggingface_inference_toolkit.serialization.audio_utils import Audioer
23
from huggingface_inference_toolkit.serialization.image_utils import Imager
34
from huggingface_inference_toolkit.serialization.json_utils import Jsoner
@@ -38,7 +39,13 @@
3839

3940
class ContentType:
4041
@staticmethod
41-
def get_deserializer(content_type: str):
42+
def get_deserializer(content_type: str, task: str):
43+
if content_type.lower().startswith("application/octet-stream"):
44+
if "audio" in task or "speech" in task:
45+
return Audioer
46+
elif "image" in task:
47+
return Imager
48+
4249
# Extract media type from content type
4350
extracted = content_type.split(";")[0]
4451
if extracted in content_type_mapping:

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ async def metrics(request):
8989
async def predict(request):
9090
global INFERENCE_HANDLERS
9191
try:
92+
task = request.path_params.get("task", HF_TASK)
9293
# extracts content from request
9394
content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE")).lower()
9495
# try to deserialize payload
95-
deserialized_body = ContentType.get_deserializer(content_type).deserialize(
96+
deserialized_body = ContentType.get_deserializer(content_type, task).deserialize(
9697
await request.body()
9798
)
9899
# checks if input schema is correct
@@ -108,7 +109,7 @@ async def predict(request):
108109
)
109110

110111
# We lazily load pipelines for alt tasks
111-
task = request.path_params.get("task", HF_TASK)
112+
112113
if task == "feature-extraction" and HF_TASK in [
113114
"sentence-similarity",
114115
"sentence-embeddings",

0 commit comments

Comments
 (0)