Skip to content

Commit 49e3c8e

Browse files
committed
chore: fixing lint
1 parent 05a7a36 commit 49e3c8e

File tree

4 files changed

+19
-29
lines changed

4 files changed

+19
-29
lines changed

src/unstructured_client/_hooks/custom/form_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def get_split_pdf_cache_tmp_data(
161161
return cache_tmp_data.lower() == "true"
162162

163163
def get_split_pdf_cache_tmp_data_dir(
164-
form_data: FormData, key: str, fallback_value: Path | str,
165-
) -> Path | str:
164+
form_data: FormData, key: str, fallback_value: str,
165+
) -> str:
166166
"""Retrieves the value for cache tmp data dir that should be used for splitting pdf.
167167
168168
In case given the number is not a "false" or "true" literal, it will use the
@@ -178,21 +178,19 @@ def get_split_pdf_cache_tmp_data_dir(
178178
"""
179179
cache_tmp_data_dir = form_data.get(key)
180180

181-
if not isinstance(cache_tmp_data_dir, str) and not isinstance(cache_tmp_data_dir, Path):
181+
if not isinstance(cache_tmp_data_dir, str):
182182
return fallback_value
183+
cache_tmp_data_path = Path(cache_tmp_data_dir)
183184

184-
if isinstance(cache_tmp_data_dir, str):
185-
cache_tmp_data_dir = Path(cache_tmp_data_dir)
186-
187-
if not cache_tmp_data_dir.exists():
185+
if not cache_tmp_data_path.exists():
188186
logger.warning(
189187
"'%s' does not exist. Using default value '%s'.",
190188
key,
191189
fallback_value,
192190
)
193191
return fallback_value
194192

195-
return cache_tmp_data_dir.resolve()
193+
return str(cache_tmp_data_path.resolve())
196194

197195

198196
def get_split_pdf_concurrency_level_param(

src/unstructured_client/_hooks/custom/pdf_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from pypdf.errors import PdfReadError
99

1010
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
11-
from unstructured_client.models import shared
1211

1312
logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
1413

src/unstructured_client/_hooks/custom/request_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
import io
55
import json
66
import logging
7-
from typing import Tuple, Any, BinaryIO, cast, IO
7+
from typing import Tuple, Any, BinaryIO
88

99
import httpx
1010
from httpx._multipart import DataField, FileField
11-
from requests_toolbelt.multipart.encoder import MultipartEncoder # type: ignore
1211

1312
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
1413
from unstructured_client._hooks.custom.form_utils import (
@@ -45,7 +44,7 @@ def get_multipart_stream_fields(request: httpx.Request) -> dict[str, Any]:
4544
return {}
4645
fields = request.stream.fields
4746

48-
mapped_fields = {}
47+
mapped_fields: dict[str, Any] = {}
4948
for field in fields:
5049
if isinstance(field, DataField):
5150
if "[]" in field.name:
@@ -114,7 +113,7 @@ def create_pdf_chunk_request(
114113
data = create_pdf_chunk_request_params(form_data, page_number)
115114
original_headers = prepare_request_headers(original_request.headers)
116115

117-
pdf_chunk_content = (
116+
pdf_chunk_content: BinaryIO | bytes = (
118117
pdf_chunk_file.getvalue()
119118
if isinstance(pdf_chunk_file, io.BytesIO)
120119
else pdf_chunk_file
@@ -135,6 +134,8 @@ def create_pdf_chunk_request(
135134
"multipart",
136135
shared.PartitionParameters,
137136
)
137+
if serialized_body is None:
138+
raise ValueError("Failed to serialize the request body.")
138139
return httpx.Request(
139140
method="POST",
140141
url=original_request.url or "",

src/unstructured_client/_hooks/custom/split_pdf_hook.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111
from collections.abc import Awaitable
1212
from functools import partial
1313
from pathlib import Path
14-
from typing import Any, Coroutine, Optional, Tuple, Union, cast, Generator, BinaryIO, Callable
14+
from typing import Any, Coroutine, Optional, Tuple, Union, cast, Generator, BinaryIO
1515

1616
import aiofiles
1717
import httpx
1818
import nest_asyncio # type: ignore
1919
from httpx import AsyncClient
2020
from pypdf import PdfReader, PdfWriter
21-
from requests_toolbelt.multipart.decoder import MultipartDecoder # type: ignore
22-
from unstructured.chunking.dispatch import chunk
2321

2422
from unstructured_client._hooks.custom import form_utils, pdf_utils, request_utils
2523
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
@@ -60,7 +58,7 @@ async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, httpx.Respons
6058

6159

6260
async def run_tasks(
63-
coroutines: list[Callable[[AsyncClient], Coroutine]],
61+
coroutines: list[partial[Coroutine[Any, Any, httpx.Response]]],
6462
allow_failed: bool = False
6563
) -> list[tuple[int, httpx.Response]]:
6664
"""Run a list of coroutines in parallel and return the results in order.
@@ -83,7 +81,7 @@ async def run_tasks(
8381
client_timeout = httpx.Timeout(60 * client_timeout_minutes)
8482

8583
async with httpx.AsyncClient(timeout=client_timeout) as client:
86-
armed_coroutines = [coro(async_client=client) for coro in coroutines]
84+
armed_coroutines = [coro(async_client=client) for coro in coroutines] # type: ignore
8785
if allow_failed:
8886
responses = await asyncio.gather(*armed_coroutines, return_exceptions=False)
8987
return list(enumerate(responses, 1))
@@ -157,12 +155,14 @@ def __init__(self) -> None:
157155
self.base_url: Optional[str] = None
158156
self.async_client: Optional[AsyncHttpClient] = None
159157
self.coroutines_to_execute: dict[
160-
str, list[Coroutine[Any, Any, httpx.Response]]
158+
str, list[partial[Coroutine[Any, Any, httpx.Response]]]
161159
] = {}
162160
self.api_successful_responses: dict[str, list[httpx.Response]] = {}
163161
self.api_failed_responses: dict[str, list[httpx.Response]] = {}
164162
self.tempdirs: dict[str, tempfile.TemporaryDirectory] = {}
165163
self.allow_failed: bool = DEFAULT_ALLOW_FAILED
164+
self.cache_tmp_data_feature: bool = DEFAULT_CACHE_TMP_DATA
165+
self.cache_tmp_data_dir: str = DEFAULT_CACHE_TMP_DATA_DIR
166166

167167
def sdk_init(
168168
self, base_url: str, client: HttpClient
@@ -266,15 +266,7 @@ def before_request(
266266
form_data = request_utils.get_multipart_stream_fields(request)
267267
if not form_data:
268268
return request
269-
# For future - avoid reading the request content as it might issue
270-
# OOM errors for large files. Instead, the `stream` (MultipartStream) parameter
271-
# should be used which contains the list of DataField or FileField objects
272-
# request_content = request.read()
273-
# request_body = request_content
274269

275-
276-
# decoded_body = MultipartDecoder(request_body, content_type)
277-
# form_data = form_utils.parse_form_data(decoded_body)
278270
split_pdf_page = form_data.get(PARTITION_FORM_SPLIT_PDF_PAGE_KEY)
279271
if split_pdf_page is None or split_pdf_page == "false":
280272
return request
@@ -505,7 +497,7 @@ def _get_pdf_chunk_paths(
505497
)
506498
self.tempdirs[operation_id] = tempdir
507499
tempdir_path = Path(tempdir.name)
508-
pdf_chunk_paths = []
500+
pdf_chunk_paths: list[Tuple[Path, int]] = []
509501
chunk_no = 0
510502
while offset < offset_end:
511503
chunk_no += 1
@@ -517,7 +509,7 @@ def _get_pdf_chunk_paths(
517509
new_pdf.add_page(page)
518510
with open(tempdir_path / f"chunk_{chunk_no}.pdf", "wb") as pdf_chunk:
519511
new_pdf.write(pdf_chunk)
520-
pdf_chunk_paths.append((pdf_chunk.name, offset))
512+
pdf_chunk_paths.append((Path(pdf_chunk.name), offset))
521513
offset += split_size
522514
return pdf_chunk_paths
523515

0 commit comments

Comments
 (0)