Skip to content

Commit a4bc2e5

Browse files
[Inference Client] Factorize inference payload build (#2601)
* Factorize inference payload build and add test
* Add comments
* Add method description
* fix style
* fix style again
* fix prepare payload helper
* experiment: try old version of workflow
* revert experiment: try old version of workflow
* Add docstring
* update docstring
* simplify json payload construction when inputs is a dict
* ignore mypy str bytes warning
* fix encoding condition
* remove unnecessary checks for parameters
1 parent c9d7865 commit a4bc2e5

File tree

4 files changed

+306
-277
lines changed

4 files changed

+306
-277
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 60 additions & 136 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@
5757
_get_unsupported_text_generation_kwargs,
5858
_import_numpy,
5959
_open_as_binary,
60+
_prepare_payload,
6061
_set_unsupported_text_generation_kwargs,
6162
_stream_chat_completion_response,
6263
_stream_text_generation_response,
@@ -364,18 +365,8 @@ def audio_classification(
364365
```
365366
"""
366367
parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
367-
if all(parameter is None for parameter in parameters.values()):
368-
# if no parameters are provided, send audio as raw data
369-
data = audio
370-
payload: Optional[Dict[str, Any]] = None
371-
else:
372-
# Or some parameters are provided -> send audio as base64 encoded string
373-
data = None
374-
payload = {"inputs": _b64_encode(audio)}
375-
for key, value in parameters.items():
376-
if value is not None:
377-
payload.setdefault("parameters", {})[key] = value
378-
response = self.post(json=payload, data=data, model=model, task="audio-classification")
368+
payload = _prepare_payload(audio, parameters=parameters, expect_binary=True)
369+
response = self.post(**payload, model=model, task="audio-classification")
379370
return AudioClassificationOutputElement.parse_obj_as_list(response)
380371

381372
def audio_to_audio(
@@ -988,7 +979,7 @@ def document_question_answering(
988979
[DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16, words=None)]
989980
```
990981
"""
991-
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
982+
inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
992983
parameters = {
993984
"doc_stride": doc_stride,
994985
"handle_impossible_answer": handle_impossible_answer,
@@ -999,10 +990,8 @@ def document_question_answering(
999990
"top_k": top_k,
1000991
"word_boxes": word_boxes,
1001992
}
1002-
for key, value in parameters.items():
1003-
if value is not None:
1004-
payload.setdefault("parameters", {})[key] = value
1005-
response = self.post(json=payload, model=model, task="document-question-answering")
993+
payload = _prepare_payload(inputs, parameters=parameters)
994+
response = self.post(**payload, model=model, task="document-question-answering")
1006995
return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
1007996

1008997
def feature_extraction(
@@ -1060,17 +1049,14 @@ def feature_extraction(
10601049
[ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
10611050
```
10621051
"""
1063-
payload: Dict = {"inputs": text}
10641052
parameters = {
10651053
"normalize": normalize,
10661054
"prompt_name": prompt_name,
10671055
"truncate": truncate,
10681056
"truncation_direction": truncation_direction,
10691057
}
1070-
for key, value in parameters.items():
1071-
if value is not None:
1072-
payload.setdefault("parameters", {})[key] = value
1073-
response = self.post(json=payload, model=model, task="feature-extraction")
1058+
payload = _prepare_payload(text, parameters=parameters)
1059+
response = self.post(**payload, model=model, task="feature-extraction")
10741060
np = _import_numpy()
10751061
return np.array(_bytes_to_dict(response), dtype="float32")
10761062

@@ -1119,12 +1105,9 @@ def fill_mask(
11191105
]
11201106
```
11211107
"""
1122-
payload: Dict = {"inputs": text}
11231108
parameters = {"targets": targets, "top_k": top_k}
1124-
for key, value in parameters.items():
1125-
if value is not None:
1126-
payload.setdefault("parameters", {})[key] = value
1127-
response = self.post(json=payload, model=model, task="fill-mask")
1109+
payload = _prepare_payload(text, parameters=parameters)
1110+
response = self.post(**payload, model=model, task="fill-mask")
11281111
return FillMaskOutputElement.parse_obj_as_list(response)
11291112

11301113
def image_classification(
@@ -1166,19 +1149,8 @@ def image_classification(
11661149
```
11671150
"""
11681151
parameters = {"function_to_apply": function_to_apply, "top_k": top_k}
1169-
1170-
if all(parameter is None for parameter in parameters.values()):
1171-
data = image
1172-
payload: Optional[Dict[str, Any]] = None
1173-
1174-
else:
1175-
data = None
1176-
payload = {"inputs": _b64_encode(image)}
1177-
for key, value in parameters.items():
1178-
if value is not None:
1179-
payload.setdefault("parameters", {})[key] = value
1180-
1181-
response = self.post(json=payload, data=data, model=model, task="image-classification")
1152+
payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1153+
response = self.post(**payload, model=model, task="image-classification")
11821154
return ImageClassificationOutputElement.parse_obj_as_list(response)
11831155

11841156
def image_segmentation(
@@ -1237,18 +1209,8 @@ def image_segmentation(
12371209
"subtask": subtask,
12381210
"threshold": threshold,
12391211
}
1240-
if all(parameter is None for parameter in parameters.values()):
1241-
# if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image
1242-
data = image
1243-
payload: Optional[Dict[str, Any]] = None
1244-
else:
1245-
# if parameters are provided, the image needs to be a base64-encoded string
1246-
data = None
1247-
payload = {"inputs": _b64_encode(image)}
1248-
for key, value in parameters.items():
1249-
if value is not None:
1250-
payload.setdefault("parameters", {})[key] = value
1251-
response = self.post(json=payload, data=data, model=model, task="image-segmentation")
1212+
payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1213+
response = self.post(**payload, model=model, task="image-segmentation")
12521214
output = ImageSegmentationOutputElement.parse_obj_as_list(response)
12531215
for item in output:
12541216
item.mask = _b64_to_image(item.mask) # type: ignore [assignment]
@@ -1323,19 +1285,8 @@ def image_to_image(
13231285
"guidance_scale": guidance_scale,
13241286
**kwargs,
13251287
}
1326-
if all(parameter is None for parameter in parameters.values()):
1327-
# Either only an image to send => send as raw bytes
1328-
data = image
1329-
payload: Optional[Dict[str, Any]] = None
1330-
else:
1331-
# if parameters are provided, the image needs to be a base64-encoded string
1332-
data = None
1333-
payload = {"inputs": _b64_encode(image)}
1334-
for key, value in parameters.items():
1335-
if value is not None:
1336-
payload.setdefault("parameters", {})[key] = value
1337-
1338-
response = self.post(json=payload, data=data, model=model, task="image-to-image")
1288+
payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1289+
response = self.post(**payload, model=model, task="image-to-image")
13391290
return _bytes_to_image(response)
13401291

13411292
def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
@@ -1493,25 +1444,15 @@ def object_detection(
14931444
```py
14941445
>>> from huggingface_hub import InferenceClient
14951446
>>> client = InferenceClient()
1496-
>>> client.object_detection("people.jpg"):
1447+
>>> client.object_detection("people.jpg")
14971448
[ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...]
14981449
```
14991450
"""
15001451
parameters = {
15011452
"threshold": threshold,
15021453
}
1503-
if all(parameter is None for parameter in parameters.values()):
1504-
# if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image
1505-
data = image
1506-
payload: Optional[Dict[str, Any]] = None
1507-
else:
1508-
# if parameters are provided, the image needs to be a base64-encoded string
1509-
data = None
1510-
payload = {"inputs": _b64_encode(image)}
1511-
for key, value in parameters.items():
1512-
if value is not None:
1513-
payload.setdefault("parameters", {})[key] = value
1514-
response = self.post(json=payload, data=data, model=model, task="object-detection")
1454+
payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
1455+
response = self.post(**payload, model=model, task="object-detection")
15151456
return ObjectDetectionOutputElement.parse_obj_as_list(response)
15161457

15171458
def question_answering(
@@ -1587,12 +1528,10 @@ def question_answering(
15871528
"max_seq_len": max_seq_len,
15881529
"top_k": top_k,
15891530
}
1590-
payload: Dict[str, Any] = {"question": question, "context": context}
1591-
for key, value in parameters.items():
1592-
if value is not None:
1593-
payload.setdefault("parameters", {})[key] = value
1531+
inputs: Dict[str, Any] = {"question": question, "context": context}
1532+
payload = _prepare_payload(inputs, parameters=parameters)
15941533
response = self.post(
1595-
json=payload,
1534+
**payload,
15961535
model=model,
15971536
task="question-answering",
15981537
)
@@ -1700,19 +1639,14 @@ def summarization(
17001639
SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....")
17011640
```
17021641
"""
1703-
payload: Dict[str, Any] = {"inputs": text}
1704-
if parameters is not None:
1705-
payload["parameters"] = parameters
1706-
else:
1642+
if parameters is None:
17071643
parameters = {
17081644
"clean_up_tokenization_spaces": clean_up_tokenization_spaces,
17091645
"generate_parameters": generate_parameters,
17101646
"truncation": truncation,
17111647
}
1712-
for key, value in parameters.items():
1713-
if value is not None:
1714-
payload.setdefault("parameters", {})[key] = value
1715-
response = self.post(json=payload, model=model, task="summarization")
1648+
payload = _prepare_payload(text, parameters=parameters)
1649+
response = self.post(**payload, model=model, task="summarization")
17161650
return SummarizationOutput.parse_obj_as_list(response)[0]
17171651

17181652
def table_question_answering(
@@ -1757,15 +1691,13 @@ def table_question_answering(
17571691
TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE')
17581692
```
17591693
"""
1760-
payload: Dict[str, Any] = {
1694+
inputs = {
17611695
"query": query,
17621696
"table": table,
17631697
}
1764-
1765-
if parameters is not None:
1766-
payload["parameters"] = parameters
1698+
payload = _prepare_payload(inputs, parameters=parameters)
17671699
response = self.post(
1768-
json=payload,
1700+
**payload,
17691701
model=model,
17701702
task="table-question-answering",
17711703
)
@@ -1813,7 +1745,11 @@ def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str]
18131745
["5", "5", "5"]
18141746
```
18151747
"""
1816-
response = self.post(json={"table": table}, model=model, task="tabular-classification")
1748+
response = self.post(
1749+
json={"table": table},
1750+
model=model,
1751+
task="tabular-classification",
1752+
)
18171753
return _bytes_to_list(response)
18181754

18191755
def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
@@ -1899,15 +1835,16 @@ def text_classification(
18991835
]
19001836
```
19011837
"""
1902-
payload: Dict[str, Any] = {"inputs": text}
19031838
parameters = {
19041839
"function_to_apply": function_to_apply,
19051840
"top_k": top_k,
19061841
}
1907-
for key, value in parameters.items():
1908-
if value is not None:
1909-
payload.setdefault("parameters", {})[key] = value
1910-
response = self.post(json=payload, model=model, task="text-classification")
1842+
payload = _prepare_payload(text, parameters=parameters)
1843+
response = self.post(
1844+
**payload,
1845+
model=model,
1846+
task="text-classification",
1847+
)
19111848
return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]
19121849

19131850
@overload
@@ -2481,7 +2418,7 @@ def text_to_image(
24812418
>>> image.save("better_astronaut.png")
24822419
```
24832420
"""
2484-
payload = {"inputs": prompt}
2421+
24852422
parameters = {
24862423
"negative_prompt": negative_prompt,
24872424
"height": height,
@@ -2493,10 +2430,8 @@ def text_to_image(
24932430
"seed": seed,
24942431
**kwargs,
24952432
}
2496-
for key, value in parameters.items():
2497-
if value is not None:
2498-
payload.setdefault("parameters", {})[key] = value # type: ignore
2499-
response = self.post(json=payload, model=model, task="text-to-image")
2433+
payload = _prepare_payload(prompt, parameters=parameters)
2434+
response = self.post(**payload, model=model, task="text-to-image")
25002435
return _bytes_to_image(response)
25012436

25022437
def text_to_speech(
@@ -2599,7 +2534,6 @@ def text_to_speech(
25992534
>>> Path("hello_world.flac").write_bytes(audio)
26002535
```
26012536
"""
2602-
payload: Dict[str, Any] = {"inputs": text}
26032537
parameters = {
26042538
"do_sample": do_sample,
26052539
"early_stopping": early_stopping,
@@ -2618,10 +2552,8 @@ def text_to_speech(
26182552
"typical_p": typical_p,
26192553
"use_cache": use_cache,
26202554
}
2621-
for key, value in parameters.items():
2622-
if value is not None:
2623-
payload.setdefault("parameters", {})[key] = value
2624-
response = self.post(json=payload, model=model, task="text-to-speech")
2555+
payload = _prepare_payload(text, parameters=parameters)
2556+
response = self.post(**payload, model=model, task="text-to-speech")
26252557
return response
26262558

26272559
def token_classification(
@@ -2683,17 +2615,15 @@ def token_classification(
26832615
]
26842616
```
26852617
"""
2686-
payload: Dict[str, Any] = {"inputs": text}
2618+
26872619
parameters = {
26882620
"aggregation_strategy": aggregation_strategy,
26892621
"ignore_labels": ignore_labels,
26902622
"stride": stride,
26912623
}
2692-
for key, value in parameters.items():
2693-
if value is not None:
2694-
payload.setdefault("parameters", {})[key] = value
2624+
payload = _prepare_payload(text, parameters=parameters)
26952625
response = self.post(
2696-
json=payload,
2626+
**payload,
26972627
model=model,
26982628
task="token-classification",
26992629
)
@@ -2769,18 +2699,15 @@ def translation(
27692699

27702700
if src_lang is None and tgt_lang is not None:
27712701
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
2772-
payload: Dict[str, Any] = {"inputs": text}
27732702
parameters = {
27742703
"src_lang": src_lang,
27752704
"tgt_lang": tgt_lang,
27762705
"clean_up_tokenization_spaces": clean_up_tokenization_spaces,
27772706
"truncation": truncation,
27782707
"generate_parameters": generate_parameters,
27792708
}
2780-
for key, value in parameters.items():
2781-
if value is not None:
2782-
payload.setdefault("parameters", {})[key] = value
2783-
response = self.post(json=payload, model=model, task="translation")
2709+
payload = _prepare_payload(text, parameters=parameters)
2710+
response = self.post(**payload, model=model, task="translation")
27842711
return TranslationOutput.parse_obj_as_list(response)[0]
27852712

27862713
def visual_question_answering(
@@ -2921,15 +2848,14 @@ def zero_shot_classification(
29212848
```
29222849
"""
29232850

2924-
parameters = {"candidate_labels": labels, "multi_label": multi_label}
2925-
if hypothesis_template is not None:
2926-
parameters["hypothesis_template"] = hypothesis_template
2927-
2851+
parameters = {
2852+
"candidate_labels": labels,
2853+
"multi_label": multi_label,
2854+
"hypothesis_template": hypothesis_template,
2855+
}
2856+
payload = _prepare_payload(text, parameters=parameters)
29282857
response = self.post(
2929-
json={
2930-
"inputs": text,
2931-
"parameters": parameters,
2932-
},
2858+
**payload,
29332859
task="zero-shot-classification",
29342860
model=model,
29352861
)
@@ -2986,13 +2912,11 @@ def zero_shot_image_classification(
29862912
if len(labels) < 2:
29872913
raise ValueError("You must specify at least 2 classes to compare.")
29882914

2989-
payload = {
2990-
"inputs": {"image": _b64_encode(image), "candidateLabels": ",".join(labels)},
2991-
}
2992-
if hypothesis_template is not None:
2993-
payload.setdefault("parameters", {})["hypothesis_template"] = hypothesis_template
2915+
inputs = {"image": _b64_encode(image), "candidateLabels": ",".join(labels)}
2916+
parameters = {"hypothesis_template": hypothesis_template}
2917+
payload = _prepare_payload(inputs, parameters=parameters)
29942918
response = self.post(
2995-
json=payload,
2919+
**payload,
29962920
model=model,
29972921
task="zero-shot-image-classification",
29982922
)

0 commit comments

Comments
 (0)