11"""Batch API client."""
22import json
33from pathlib import Path
4- from typing import Any , Dict , List , Optional , Union
4+ from typing import Any , Dict , List , Optional , Tuple , Union
55
6- import requests
6+ from requests import Session
77
88from hume ._batch .batch_job import BatchJob
99from hume ._batch .batch_job_details import BatchJobDetails
@@ -55,6 +55,7 @@ def __init__(
5555 timeout (int): Time in seconds before canceling long-running Hume API requests.
5656 """
5757 self ._timeout = timeout
58+ self ._session = Session ()
5859 super ().__init__ (api_key , * args , ** kwargs )
5960
6061 @classmethod
@@ -118,7 +119,7 @@ def get_job_details(self, job_id: str) -> BatchJobDetails:
118119 BatchJobDetails: Batch job details.
119120 """
120121 endpoint = self ._construct_endpoint (f"jobs/{ job_id } " )
121- response = requests .get (
122+ response = self . _session .get (
122123 endpoint ,
123124 timeout = self ._timeout ,
124125 headers = self ._get_client_headers (),
@@ -148,7 +149,7 @@ def get_job_predictions(self, job_id: str) -> Any:
148149 Any: Batch job predictions.
149150 """
150151 endpoint = self ._construct_endpoint (f"jobs/{ job_id } /predictions" )
151- response = requests .get (
152+ response = self . _session .get (
152153 endpoint ,
153154 timeout = self ._timeout ,
154155 headers = self ._get_client_headers (),
@@ -179,7 +180,7 @@ def download_job_artifacts(self, job_id: str, filepath: Union[str, Path]) -> Non
179180 Any: Batch job artifacts.
180181 """
181182 endpoint = self ._construct_endpoint (f"jobs/{ job_id } /artifacts" )
182- response = requests .get (
183+ response = self . _session .get (
183184 endpoint ,
184185 timeout = self ._timeout ,
185186 headers = self ._get_client_headers (),
@@ -212,7 +213,7 @@ def _construct_request(
212213 def _submit_job (
213214 self ,
214215 request_body : Any ,
215- files : Optional [List [Union [str , Path ]]],
216+ filepaths : Optional [List [Union [str , Path ]]],
216217 ) -> BatchJob :
217218 """Start a job for batch processing by passing a JSON request body.
218219
@@ -221,7 +222,7 @@ def _submit_job(
221222
222223 Args:
223224 request_body (Any): JSON request body to be passed to the batch API.
224- files (Optional[List[Union[str, Path]]]): List of paths to files on the local disk to be processed.
225+ filepaths (Optional[List[Union[str, Path]]]): List of paths to files on the local disk to be processed.
225226
226227 Raises:
227228 HumeClientException: If the batch job fails to start.
@@ -231,21 +232,20 @@ def _submit_job(
231232 """
232233 endpoint = self ._construct_endpoint ("jobs" )
233234
234- if files is None :
235- response = requests .post (
235+ if filepaths is None :
236+ response = self . _session .post (
236237 endpoint ,
237238 json = request_body ,
238239 timeout = self ._timeout ,
239240 headers = self ._get_client_headers (),
240241 )
241242 else :
242- post_files = [("file" , Path (path ).read_bytes ()) for path in files ]
243- post_files .append (("json" , json .dumps (request_body ).encode ("utf-8" )))
244- response = requests .post (
243+ form_data = self ._get_multipart_form_data (request_body , filepaths )
244+ response = self ._session .post (
245245 endpoint ,
246246 timeout = self ._timeout ,
247247 headers = self ._get_client_headers (),
248- files = post_files ,
248+ files = form_data ,
249249 )
250250
251251 try :
@@ -268,3 +268,30 @@ def _submit_job(
268268 raise HumeClientException (f"Unexpected error when starting batch job: { body } " )
269269
270270 return BatchJob (self , body ["job_id" ])
271+
def _get_multipart_form_data(
    self,
    request_body: Any,
    filepaths: List[Union[str, Path]],
) -> List[Tuple[str, Union[bytes, Tuple[str, bytes]]]]:
    """Build the multipart form parts for a batch job POST request.

    Each local file becomes one ``"file"`` part carrying both its base
    filename and its raw bytes. The JSON request body is serialized and
    appended as a final ``"json"`` part.

    Args:
        request_body (Any): JSON request body to be passed to the batch API.
            It must be serializable by ``json.dumps``.
        filepaths (List[Union[str, Path]]): List of paths to files on the
            local disk to be processed.

    Returns:
        List[Tuple[str, Union[bytes, Tuple[str, bytes]]]]: One
            ``("file", (filename, file_bytes))`` tuple per input path, in
            input order, followed by a single ``("json", encoded_body)`` tuple.
    """
    form_data: List[Tuple[str, Union[bytes, Tuple[str, bytes]]]] = []
    for filepath in filepaths:
        # Normalize so that both str and Path inputs expose .name / .read_bytes().
        path = Path(filepath)
        # The whole file is read into memory; only the base name is sent,
        # not the directory portion of the local path.
        form_data.append(("file", (path.name, path.read_bytes())))

    # The request body travels as its own UTF-8 encoded part after the files.
    form_data.append(("json", json.dumps(request_body).encode("utf-8")))
    return form_data
0 commit comments