Commit 2572405

Authored by fern-support, fern-api[bot], armandobelardo, and norman-codes
🌿 manually modify start job from local file to match server expectations (#187)
Co-authored-by: fern-api <115122769+fern-api[bot]@users.noreply.github.com>
Co-authored-by: armandobelardo <armandoubelardo@gmail.com>
Co-authored-by: Norman Bukingolts <78187320+norman-codes@users.noreply.github.com>
1 parent 3143cb8 · commit 2572405
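In practice, the fix amounts to serializing the typed request config into the multipart "json" part the batch endpoint expects, alongside the uploaded media in the "file" part. A minimal sketch of just that encoding step, reusing the helpers named in the diff below (the absolute hume.core import path and the sample config are assumptions, not part of the commit):

import json as jsonlib

from hume.core.jsonable_encoder import jsonable_encoder  # assumed absolute path for the "...core" relative import in the diff
from hume.expression_measurement.batch.types.face import Face
from hume.expression_measurement.batch.types.inference_base_request import InferenceBaseRequest
from hume.expression_measurement.batch.types.models import Models

# Illustrative config: run only the face model.
config = InferenceBaseRequest(models=Models(face=Face()))

# The new helper JSON-encodes the pydantic config and sends it as the "json"
# multipart field; the local media goes in the "file" field(s).
multipart_json = jsonlib.dumps(jsonable_encoder(config)).encode("utf-8")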

File tree

2 files changed · +151 −7 lines changed


src/hume/expression_measurement/batch/client_with_utils.py

Lines changed: 134 additions & 1 deletion
@@ -1,9 +1,17 @@
 import aiofiles
 import typing
+import json as jsonlib
+from json.decoder import JSONDecodeError

 from ...core.request_options import RequestOptions
+from ...core.jsonable_encoder import jsonable_encoder
+from ... import core

+from .types.inference_base_request import InferenceBaseRequest
+from ...core.pydantic_utilities import parse_obj_as
+from .types.job_id import JobId
 from .client import AsyncBatchClient, BatchClient
+from ...core.api_error import ApiError

 class BatchClientWithUtils(BatchClient):
     def get_and_write_job_artifacts(
@@ -47,6 +55,69 @@ def get_and_write_job_artifacts(
             for chunk in self.get_job_artifacts(id=id, request_options=request_options):
                 f.write(chunk)

+    def start_inference_job_from_local_file(
+        self,
+        *,
+        file: typing.List[core.File],
+        json: typing.Optional[InferenceBaseRequest] = None,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> str:
+        """
+        Start a new batch inference job.
+
+        Parameters
+        ----------
+        file : typing.List[core.File]
+            See core.File for more documentation
+
+        json : typing.Optional[InferenceBaseRequest]
+            The inference job configuration.
+
+        request_options : typing.Optional[RequestOptions]
+            Request-specific configuration.
+
+        Returns
+        -------
+        str
+
+
+        Examples
+        --------
+        from hume import HumeClient
+
+        client = HumeClient(
+            api_key="YOUR_API_KEY",
+        )
+        client.expression_measurement.batch.start_inference_job_from_local_file()
+        """
+        files: typing.Dict[str, typing.Any] = {
+            "file": file,
+        }
+        if json is not None:
+            files["json"] = jsonlib.dumps(jsonable_encoder(json)).encode("utf-8")
+
+        _response = self._client_wrapper.httpx_client.request(
+            "v0/batch/jobs",
+            method="POST",
+            files=files,
+            request_options=request_options,
+        )
+        try:
+            if 200 <= _response.status_code < 300:
+                _parsed_response = typing.cast(
+                    JobId,
+                    parse_obj_as(
+                        type_=JobId, # type: ignore
+                        object_=_response.json(),
+                    ),
+                )
+                return _parsed_response.job_id
+            _response_json = _response.json()
+        except JSONDecodeError:
+            raise ApiError(status_code=_response.status_code, body=_response.text)
+        raise ApiError(status_code=_response.status_code, body=_response_json)
+
+
 class AsyncBatchClientWithUtils(AsyncBatchClient):
     async def get_and_write_job_artifacts(
         self,
@@ -87,4 +158,66 @@ async def get_and_write_job_artifacts(
         """
         async with aiofiles.open(file_name, mode='wb') as f:
             async for chunk in self.get_job_artifacts(id=id, request_options=request_options):
-                await f.write(chunk)
+                await f.write(chunk)
+
+    async def start_inference_job_from_local_file(
+        self,
+        *,
+        file: typing.List[core.File],
+        json: typing.Optional[InferenceBaseRequest] = None,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> str:
+        """
+        Start a new batch inference job.
+
+        Parameters
+        ----------
+        file : typing.List[core.File]
+            See core.File for more documentation
+
+        json : typing.Optional[InferenceBaseRequest]
+            The inference job configuration.
+
+        request_options : typing.Optional[RequestOptions]
+            Request-specific configuration.
+
+        Returns
+        -------
+        str
+
+
+        Examples
+        --------
+        from hume import HumeClient
+
+        client = HumeClient(
+            api_key="YOUR_API_KEY",
+        )
+        client.expression_measurement.batch.start_inference_job_from_local_file()
+        """
+        files: typing.Dict[str, typing.Any] = {
+            "file": file,
+        }
+        if json is not None:
+            files["json"] = jsonlib.dumps(jsonable_encoder(json)).encode("utf-8")
+
+        _response = await self._client_wrapper.httpx_client.request(
+            "v0/batch/jobs",
+            method="POST",
+            files=files,
+            request_options=request_options,
+        )
+        try:
+            if 200 <= _response.status_code < 300:
+                _parsed_response = typing.cast(
+                    JobId,
+                    parse_obj_as(
+                        type_=JobId, # type: ignore
+                        object_=_response.json(),
+                    ),
+                )
+                return _parsed_response.job_id
+            _response_json = _response.json()
+        except JSONDecodeError:
+            raise ApiError(status_code=_response.status_code, body=_response.text)
+        raise ApiError(status_code=_response.status_code, body=_response_json)
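
For reference, a hedged sketch of calling the new synchronous helper with an actual local file rather than the empty docstring example above. It assumes a core.File entry can be an open binary file handle (typical for Fern-generated SDKs) and uses a hypothetical path faces.jpg; the return value is the started job's ID:

from hume import HumeClient
from hume.expression_measurement.batch.types.face import Face
from hume.expression_measurement.batch.types.inference_base_request import InferenceBaseRequest
from hume.expression_measurement.batch.types.models import Models

client = HumeClient(api_key="YOUR_API_KEY")

# "faces.jpg" is a hypothetical local path; the open handle is passed as one
# entry of the `file` list and the config is sent as the multipart "json" part.
with open("faces.jpg", "rb") as f:
    job_id = client.expression_measurement.batch.start_inference_job_from_local_file(
        file=[f],
        json=InferenceBaseRequest(models=Models(face=Face())),
    )
print(job_id)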

tests/custom/test_client.py

Lines changed: 17 additions & 6 deletions
@@ -1,12 +1,11 @@
 import pytest
 import aiofiles

-from hume.client import AsyncHumeClient
+from hume.client import AsyncHumeClient, HumeClient
+from hume.expression_measurement.batch.types.face import Face
+from hume.expression_measurement.batch.types.inference_base_request import InferenceBaseRequest
+from hume.expression_measurement.batch.types.models import Models

-# Get started with writing tests with pytest at https://docs.pytest.org
-@pytest.mark.skip(reason="Unimplemented")
-def test_client() -> None:
-    assert True == True

 @pytest.mark.skip(reason="CI does not have authentication.")
 async def test_write_job_artifacts() -> None:
@@ -20,4 +19,16 @@ async def test_get_job_predictions() -> None:
     client = AsyncHumeClient(api_key="MY_API_KEY")
     await client.expression_measurement.batch.get_job_predictions(id="my-job-id", request_options={
         "max_retries": 3,
-    })
+    })
+
+@pytest.mark.skip(reason="CI does not have authentication.")
+async def test_start_inference_job_from_local_file() -> None:
+    client = HumeClient(api_key="MY_API_KEY")
+    client.expression_measurement.batch.start_inference_job_from_local_file(
+        file=[],
+        json=InferenceBaseRequest(
+            models=Models(
+                face=Face()
+            )
+        )
+    )
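
The new test exercises the synchronous client only; an analogous async call might look like the sketch below, assuming the async client exposes the same expression_measurement.batch path (the asyncio scaffolding and local path are illustrative, not from the commit):

import asyncio

from hume.client import AsyncHumeClient
from hume.expression_measurement.batch.types.face import Face
from hume.expression_measurement.batch.types.inference_base_request import InferenceBaseRequest
from hume.expression_measurement.batch.types.models import Models


async def main() -> None:
    client = AsyncHumeClient(api_key="MY_API_KEY")
    # Hypothetical local file; the helper uploads it and returns the new job's ID.
    with open("faces.jpg", "rb") as f:
        job_id = await client.expression_measurement.batch.start_inference_job_from_local_file(
            file=[f],
            json=InferenceBaseRequest(models=Models(face=Face())),
        )
    print(job_id)


asyncio.run(main())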
