Skip to content

Commit 800a858

Browse files
Pass filename when uploading local files with HumeBatchClient (#76)
1 parent 8efc82e commit 800a858

File tree

4 files changed

+80
-17
lines changed

4 files changed

+80
-17
lines changed

hume/_batch/hume_batch_client.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Batch API client."""
22
import json
33
from pathlib import Path
4-
from typing import Any, Dict, List, Optional, Union
4+
from typing import Any, Dict, List, Optional, Tuple, Union
55

6-
import requests
6+
from requests import Session
77

88
from hume._batch.batch_job import BatchJob
99
from hume._batch.batch_job_details import BatchJobDetails
@@ -55,6 +55,7 @@ def __init__(
5555
timeout (int): Time in seconds before canceling long-running Hume API requests.
5656
"""
5757
self._timeout = timeout
58+
self._session = Session()
5859
super().__init__(api_key, *args, **kwargs)
5960

6061
@classmethod
@@ -118,7 +119,7 @@ def get_job_details(self, job_id: str) -> BatchJobDetails:
118119
BatchJobDetails: Batch job details.
119120
"""
120121
endpoint = self._construct_endpoint(f"jobs/{job_id}")
121-
response = requests.get(
122+
response = self._session.get(
122123
endpoint,
123124
timeout=self._timeout,
124125
headers=self._get_client_headers(),
@@ -148,7 +149,7 @@ def get_job_predictions(self, job_id: str) -> Any:
148149
Any: Batch job predictions.
149150
"""
150151
endpoint = self._construct_endpoint(f"jobs/{job_id}/predictions")
151-
response = requests.get(
152+
response = self._session.get(
152153
endpoint,
153154
timeout=self._timeout,
154155
headers=self._get_client_headers(),
@@ -179,7 +180,7 @@ def download_job_artifacts(self, job_id: str, filepath: Union[str, Path]) -> Non
179180
Any: Batch job artifacts.
180181
"""
181182
endpoint = self._construct_endpoint(f"jobs/{job_id}/artifacts")
182-
response = requests.get(
183+
response = self._session.get(
183184
endpoint,
184185
timeout=self._timeout,
185186
headers=self._get_client_headers(),
@@ -212,7 +213,7 @@ def _construct_request(
212213
def _submit_job(
213214
self,
214215
request_body: Any,
215-
files: Optional[List[Union[str, Path]]],
216+
filepaths: Optional[List[Union[str, Path]]],
216217
) -> BatchJob:
217218
"""Start a job for batch processing by passing a JSON request body.
218219
@@ -221,7 +222,7 @@ def _submit_job(
221222
222223
Args:
223224
request_body (Any): JSON request body to be passed to the batch API.
224-
files (Optional[List[Union[str, Path]]]): List of paths to files on the local disk to be processed.
225+
filepaths (Optional[List[Union[str, Path]]]): List of paths to files on the local disk to be processed.
225226
226227
Raises:
227228
HumeClientException: If the batch job fails to start.
@@ -231,21 +232,20 @@ def _submit_job(
231232
"""
232233
endpoint = self._construct_endpoint("jobs")
233234

234-
if files is None:
235-
response = requests.post(
235+
if filepaths is None:
236+
response = self._session.post(
236237
endpoint,
237238
json=request_body,
238239
timeout=self._timeout,
239240
headers=self._get_client_headers(),
240241
)
241242
else:
242-
post_files = [("file", Path(path).read_bytes()) for path in files]
243-
post_files.append(("json", json.dumps(request_body).encode("utf-8")))
244-
response = requests.post(
243+
form_data = self._get_multipart_form_data(request_body, filepaths)
244+
response = self._session.post(
245245
endpoint,
246246
timeout=self._timeout,
247247
headers=self._get_client_headers(),
248-
files=post_files,
248+
files=form_data,
249249
)
250250

251251
try:
@@ -268,3 +268,30 @@ def _submit_job(
268268
raise HumeClientException(f"Unexpected error when starting batch job: {body}")
269269

270270
return BatchJob(self, body["job_id"])
271+
272+
def _get_multipart_form_data(
    self,
    request_body: Any,
    filepaths: List[Union[str, Path]],
) -> List[Tuple[str, Union[bytes, Tuple[str, bytes]]]]:
    """Build the multipart form parts for a job submission with local files.

    Each local file becomes a ``"file"`` part carrying both its filename and
    its raw bytes, so the filename is sent along with the upload. The JSON
    request body is serialized and appended as a final ``"json"`` part.

    Args:
        request_body (Any): JSON request body to be passed to the batch API.
            It must be serializable by ``json.dumps``.
        filepaths (List[Union[str, Path]]): List of paths to files on the local disk to be processed.

    Returns:
        List[Tuple[str, Union[bytes, Tuple[str, bytes]]]]: A list of tuples representing
            the multipart form data for the POST request. File parts come first, in the
            order given, followed by the single ``"json"`` part.

    Raises:
        OSError: If one of the files cannot be read from disk.
    """
    form_data: List[Tuple[str, Union[bytes, Tuple[str, bytes]]]] = []
    for filepath in filepaths:
        path = Path(filepath)
        # Send only the basename (not the full local path) together with the file contents.
        form_data.append(("file", (path.name, path.read_bytes())))

    # The request body travels as one more form part, UTF-8 encoded, after all files.
    form_data.append(("json", json.dumps(request_body).encode("utf-8")))
    return form_data

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ license = "Proprietary"
2525
name = "hume"
2626
readme = "README.md"
2727
repository = "https://github.com/HumeAI/hume-python-sdk"
28-
version = "0.3.4"
28+
version = "0.3.5"
2929

3030
[tool.poetry.dependencies]
3131
python = ">=3.8.1,<4"

tests/batch/test_hume_batch_client.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unittest.mock import MagicMock
22

33
import pytest
4-
from pytest import MonkeyPatch
4+
from pytest import MonkeyPatch, TempPathFactory
55

66
from hume import BatchJob, HumeBatchClient
77
from hume.models.config import BurstConfig, FaceConfig, LanguageConfig, ProsodyConfig
@@ -96,3 +96,36 @@ def test_language(self, batch_client: HumeBatchClient):
9696
def test_get_job(self, batch_client: HumeBatchClient):
9797
job = batch_client.get_job("mock-job-id")
9898
assert job.id == "mock-job-id"
99+
100+
def test_files(self, batch_client: HumeBatchClient):
    """Submitting a job with a local file forwards the file path list to ``_submit_job``.

    The final assertion uses ``assert_called_once_with``, so this test relies on the
    ``batch_client`` fixture having replaced ``_submit_job`` with a mock.
    """
    mock_filepath = "my-audio.mp3"
    config = ProsodyConfig(identify_speakers=True)

    job = batch_client.submit_job(None, [config], files=[mock_filepath])

    assert isinstance(job, BatchJob)
    assert job.id == "mock-job-id"
    # The request body and the untouched list of file paths are passed positionally.
    batch_client._submit_job.assert_called_once_with(
        {
            "urls": None,
            "models": {
                "prosody": {
                    "identify_speakers": True,
                },
            },
        },
        [mock_filepath],
    )
117+
118+
def test_get_multipart_form_data(
    self,
    batch_client: HumeBatchClient,
    tmp_path_factory: TempPathFactory,
):
    """The form data holds a (filename, bytes) file part followed by the encoded JSON body."""
    dirpath = tmp_path_factory.mktemp("multipart")
    filepath = dirpath / "my-audio.mp3"
    # Explicit encoding keeps the byte comparison below independent of the locale default.
    filepath.write_text("I can't believe this test passed!", encoding="utf-8")

    request_body = {"mock": "body"}
    filepaths = [filepath]
    result = batch_client._get_multipart_form_data(request_body, filepaths)

    assert result == [
        ("file", ("my-audio.mp3", b"I can't believe this test passed!")),
        ("json", b'{"mock": "body"}'),
    ]

tests/batch/test_service_hume_batch_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def test_job_invalid_api_key(self, eval_data: EvalData, batch_client: HumeBatchC
121121
rehydrated_job = BatchJob(invalid_client, job.id)
122122
rehydrated_job.await_complete(10)
123123

124-
def test_local_file_upload(self, eval_data: EvalData, batch_client: HumeBatchClient,
125-
tmp_path_factory: TempPathFactory):
124+
def test_local_file_upload_simple(self, eval_data: EvalData, batch_client: HumeBatchClient,
125+
tmp_path_factory: TempPathFactory):
126126
data_url = eval_data["image-obama-face"]
127127
data_filepath = tmp_path_factory.mktemp("data-dir") / "obama.png"
128128
urlretrieve(data_url, data_filepath)
@@ -131,6 +131,9 @@ def test_local_file_upload(self, eval_data: EvalData, batch_client: HumeBatchCli
131131
job = batch_client.submit_job([], [config], files=[data_filepath])
132132
self.check_job(job, config, FaceConfig, job_files_dirpath, complete_config=False)
133133

134+
predictions = job.get_predictions()
135+
assert predictions[0]["source"]["filename"] == "obama.png"
136+
134137
def test_local_file_upload_configure(self, eval_data: EvalData, batch_client: HumeBatchClient,
135138
tmp_path_factory: TempPathFactory):
136139
data_url = eval_data["text-happy-place"]

0 commit comments

Comments
 (0)