Skip to content

Commit 17c2d07

Browse files
authored
rfctr improve partitioner typing (#2963)
**Summary** Remedy the persistent type errors when importing `unstructured`. Give the partitioner type annotations a general scrubbing while we're at it.
1 parent 39b74a2 commit 17c2d07

File tree

25 files changed

+332
-336
lines changed

25 files changed

+332
-336
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## 0.13.7-dev4
1+
## 0.13.7-dev5
22

33
### Enhancements
44

unstructured/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.13.7-dev4" # pragma: no cover
1+
__version__ = "0.13.7-dev5" # pragma: no cover

unstructured/partition/api.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1+
from __future__ import annotations
2+
13
import contextlib
24
import json
3-
from typing import (
4-
IO,
5-
List,
6-
Optional,
7-
)
5+
from typing import IO, Optional
86

97
import requests
108
from unstructured_client import UnstructuredClient
@@ -25,7 +23,7 @@ def partition_via_api(
2523
api_key: str = "",
2624
metadata_filename: Optional[str] = None,
2725
**request_kwargs,
28-
) -> List[Element]:
26+
) -> list[Element]:
2927
"""Partitions a document using the Unstructured REST API. This is equivalent to
3028
running the document through partition.
3129
@@ -84,10 +82,7 @@ def partition_via_api(
8482
"If file is specified in partition_via_api, "
8583
"metadata_filename must be specified as well.",
8684
)
87-
files = shared.Files(
88-
content=file,
89-
file_name=metadata_filename,
90-
)
85+
files = shared.Files(content=file, file_name=metadata_filename)
9186

9287
# NOTE(christine): Converts all list type parameters to JSON formatted strings
9388
# (e.g. ["image", "table"] -> '["image", "table"]')
@@ -96,10 +91,7 @@ def partition_via_api(
9691
if isinstance(v, list):
9792
request_kwargs[k] = json.dumps(v)
9893

99-
req = shared.PartitionParameters(
100-
files=files,
101-
**request_kwargs,
102-
)
94+
req = shared.PartitionParameters(files=files, **request_kwargs)
10395
response = sdk.general.partition(req)
10496

10597
if response.status_code == 200:
@@ -111,15 +103,15 @@ def partition_via_api(
111103

112104

113105
def partition_multiple_via_api(
114-
filenames: Optional[List[str]] = None,
115-
content_types: Optional[List[str]] = None,
116-
files: Optional[List[str]] = None,
117-
file_filenames: Optional[List[str]] = None,
106+
filenames: Optional[list[str]] = None,
107+
content_types: Optional[list[str]] = None,
108+
files: Optional[list[str]] = None,
109+
file_filenames: Optional[list[str]] = None,
118110
api_url: str = "https://api.unstructured.io/general/v0/general",
119111
api_key: str = "",
120-
metadata_filenames: Optional[List[str]] = None,
112+
metadata_filenames: Optional[list[str]] = None,
121113
**request_kwargs,
122-
) -> List[List[Element]]:
114+
) -> list[list[Element]]:
123115
"""Partitions multiple documents using the Unstructured REST API by batching
124116
the documents into a single HTTP request.
125117

unstructured/partition/auto.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
"""Provides partitioning with automatic file-type detection."""
2+
3+
from __future__ import annotations
4+
15
import io
2-
from typing import IO, Callable, Dict, List, Optional, Tuple
6+
from typing import IO, Any, Callable, Optional
37

48
import requests
59

6-
from unstructured.documents.elements import DataSourceMetadata
10+
from unstructured.documents.elements import DataSourceMetadata, Element
711
from unstructured.file_utils.filetype import (
812
FILETYPE_TO_MIMETYPE,
913
STR_TO_FILETYPE,
@@ -16,15 +20,13 @@
1620
from unstructured.partition.email import partition_email
1721
from unstructured.partition.html import partition_html
1822
from unstructured.partition.json import partition_json
19-
from unstructured.partition.lang import (
20-
check_language_args,
21-
)
23+
from unstructured.partition.lang import check_language_args
2224
from unstructured.partition.text import partition_text
2325
from unstructured.partition.utils.constants import PartitionStrategy
2426
from unstructured.partition.xml import partition_xml
2527
from unstructured.utils import dependency_exists
2628

27-
PARTITION_WITH_EXTRAS_MAP: Dict[str, Callable] = {}
29+
PARTITION_WITH_EXTRAS_MAP: dict[str, Callable[..., list[Element]]] = {}
2830

2931
if dependency_exists("pandas"):
3032
from unstructured.partition.csv import partition_csv
@@ -114,7 +116,7 @@
114116

115117
def _get_partition_with_extras(
116118
doc_type: str,
117-
partition_with_extras_map: Optional[Dict[str, Callable]] = None,
119+
partition_with_extras_map: Optional[dict[str, Callable[..., list[Element]]]] = None,
118120
):
119121
if partition_with_extras_map is None:
120122
partition_with_extras_map = PARTITION_WITH_EXTRAS_MAP
@@ -138,15 +140,15 @@ def partition(
138140
strategy: str = PartitionStrategy.AUTO,
139141
encoding: Optional[str] = None,
140142
paragraph_grouper: Optional[Callable[[str], str]] = None,
141-
headers: Dict[str, str] = {},
142-
skip_infer_table_types: List[str] = [],
143+
headers: dict[str, str] = {},
144+
skip_infer_table_types: list[str] = [],
143145
ssl_verify: bool = True,
144146
ocr_languages: Optional[str] = None, # changing to optional for deprecation
145-
languages: Optional[List[str]] = None,
147+
languages: Optional[list[str]] = None,
146148
detect_language_per_element: bool = False,
147149
pdf_infer_table_structure: bool = True,
148150
extract_images_in_pdf: bool = False,
149-
extract_image_block_types: Optional[List[str]] = None,
151+
extract_image_block_types: Optional[list[str]] = None,
150152
extract_image_block_output_dir: Optional[str] = None,
151153
extract_image_block_to_payload: bool = False,
152154
xml_keep_tags: bool = False,
@@ -157,7 +159,7 @@ def partition(
157159
model_name: Optional[str] = None, # to be deprecated
158160
date_from_file_object: bool = False,
159161
starting_page_number: int = 1,
160-
**kwargs,
162+
**kwargs: Any,
161163
):
162164
"""Partitions a document into its constituent elements. Will use libmagic to determine
163165
the file's type and route it to the appropriate partitioning function. Applies the default
@@ -422,8 +424,8 @@ def partition(
422424
elif filetype == FileType.PDF:
423425
_partition_pdf = _get_partition_with_extras("pdf")
424426
elements = _partition_pdf(
425-
filename=filename, # type: ignore
426-
file=file, # type: ignore
427+
filename=filename,
428+
file=file,
427429
url=None,
428430
include_page_breaks=include_page_breaks,
429431
infer_table_structure=infer_table_structure,
@@ -438,9 +440,10 @@ def partition(
438440
**kwargs,
439441
)
440442
elif filetype in IMAGE_FILETYPES:
441-
elements = partition_image(
442-
filename=filename, # type: ignore
443-
file=file, # type: ignore
443+
_partition_image = _get_partition_with_extras("image")
444+
elements = _partition_image(
445+
filename=filename,
446+
file=file,
444447
url=None,
445448
include_page_breaks=include_page_breaks,
446449
infer_table_structure=infer_table_structure,
@@ -557,10 +560,10 @@ def partition(
557560
def file_and_type_from_url(
558561
url: str,
559562
content_type: Optional[str] = None,
560-
headers: Dict[str, str] = {},
563+
headers: dict[str, str] = {},
561564
ssl_verify: bool = True,
562565
request_timeout: Optional[int] = None,
563-
) -> Tuple[io.BytesIO, Optional[FileType]]:
566+
) -> tuple[io.BytesIO, Optional[FileType]]:
564567
response = requests.get(url, headers=headers, verify=ssl_verify, timeout=request_timeout)
565568
file = io.BytesIO(response.content)
566569

@@ -575,7 +578,7 @@ def file_and_type_from_url(
575578

576579
def decide_table_extraction(
577580
filetype: Optional[FileType],
578-
skip_infer_table_types: List[str],
581+
skip_infer_table_types: list[str],
579582
pdf_infer_table_structure: bool,
580583
) -> bool:
581584
doc_type = filetype.name.lower() if filetype else None

unstructured/partition/common.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import datetime
77
from io import BufferedReader, BytesIO, TextIOWrapper
88
from tempfile import SpooledTemporaryFile
9-
from typing import IO, TYPE_CHECKING, Any, BinaryIO, List, Optional
9+
from typing import IO, TYPE_CHECKING, Any, Optional, TypeVar, cast
1010

1111
import emoji
1212
from tabulate import tabulate
@@ -191,14 +191,14 @@ def layout_list_to_list_items(
191191
coordinate_system: Optional[CoordinateSystem],
192192
metadata: Optional[ElementMetadata],
193193
detection_origin: Optional[str],
194-
) -> List[Element]:
194+
) -> list[Element]:
195195
"""Converts a list LayoutElement to a list of ListItem elements."""
196196
split_items = ENUMERATED_BULLETS_RE.split(text) if text else []
197197
# NOTE(robinson) - this means there wasn't a match for the enumerated bullets
198198
if len(split_items) == 1:
199199
split_items = UNICODE_BULLETS_RE.split(text) if text else []
200200

201-
list_items: List[Element] = []
201+
list_items: list[Element] = []
202202
for text_segment in split_items:
203203
if len(text_segment.strip()) > 0:
204204
# Both `coordinates` and `coordinate_system` must be present
@@ -216,13 +216,13 @@ def layout_list_to_list_items(
216216

217217

218218
def set_element_hierarchy(
219-
elements: List[Element], ruleset: dict[str, list[str]] = HIERARCHY_RULE_SET
219+
elements: list[Element], ruleset: dict[str, list[str]] = HIERARCHY_RULE_SET
220220
) -> list[Element]:
221221
"""Sets the parent_id for each element in the list of elements
222222
based on the element's category, depth and a ruleset
223223
224224
"""
225-
stack: List[Element] = []
225+
stack: list[Element] = []
226226
for element in elements:
227227
if element.metadata.parent_id is not None:
228228
continue
@@ -274,7 +274,7 @@ def add_element_metadata(
274274
coordinate_system: Optional[CoordinateSystem] = None,
275275
image_path: Optional[str] = None,
276276
detection_origin: Optional[str] = None,
277-
languages: Optional[List[str]] = None,
277+
languages: Optional[list[str]] = None,
278278
**kwargs: Any,
279279
) -> Element:
280280
"""Adds document metadata to the document element.
@@ -338,7 +338,7 @@ def remove_element_metadata(layout_elements) -> list[Element]:
338338
339339
Document metadata includes information like the filename, source url, and page number.
340340
"""
341-
elements: List[Element] = []
341+
elements: list[Element] = []
342342
metadata = ElementMetadata()
343343
for layout_element in layout_elements:
344344
element = normalize_layout_element(layout_element)
@@ -431,16 +431,25 @@ def exactly_one(**kwargs: Any) -> None:
431431
raise ValueError(message)
432432

433433

434-
def spooled_to_bytes_io_if_needed(
435-
file_obj: bytes | BinaryIO | SpooledTemporaryFile[bytes] | None,
436-
) -> bytes | BinaryIO | None:
437-
if isinstance(file_obj, SpooledTemporaryFile):
438-
file_obj.seek(0)
439-
contents = file_obj.read()
440-
return BytesIO(contents)
441-
else:
442-
# Return the original file object if it's not a SpooledTemporaryFile
443-
return file_obj
434+
_T = TypeVar("_T")
435+
436+
437+
def spooled_to_bytes_io_if_needed(file: _T | SpooledTemporaryFile[bytes]) -> _T | BytesIO:
438+
"""Convert `file` to `BytesIO` when it is a `SpooledTemporaryFile`.
439+
440+
Note that `file` does not need to be IO[bytes]. It can be `None` or `bytes` and this function
441+
will not complain.
442+
443+
In Python <3.11, `SpooledTemporaryFile` does not implement `.readable()` or `.seekable()` which
444+
triggers an exception when the file is loaded by certain packages. In particular, the stdlib
445+
`zipfile.Zipfile` raises on opening a `SpooledTemporaryFile` as does `Pandas.read_csv()`.
446+
"""
447+
if isinstance(file, SpooledTemporaryFile):
448+
file.seek(0)
449+
return BytesIO(cast(bytes, file.read()))
450+
451+
# -- return `file` unchanged otherwise --
452+
return file
444453

445454

446455
def convert_to_bytes(file: bytes | IO[bytes]) -> bytes:
@@ -537,16 +546,16 @@ def document_to_element_list(
537546
source_format: Optional[str] = None,
538547
detection_origin: Optional[str] = None,
539548
sort_mode: str = SORT_MODE_XY_CUT,
540-
languages: Optional[List[str]] = None,
549+
languages: Optional[list[str]] = None,
541550
starting_page_number: int = 1,
542551
**kwargs: Any,
543-
) -> List[Element]:
552+
) -> list[Element]:
544553
"""Converts a DocumentLayout object to a list of unstructured elements."""
545-
elements: List[Element] = []
554+
elements: list[Element] = []
546555

547556
num_pages = len(document.pages)
548557
for page_number, page in enumerate(document.pages, start=starting_page_number):
549-
page_elements: List[Element] = []
558+
page_elements: list[Element] = []
550559

551560
page_image_metadata = _get_page_image_metadata(page)
552561
image_format = page_image_metadata.get("format")
@@ -566,7 +575,7 @@ def document_to_element_list(
566575
infer_list_items=infer_list_items,
567576
source_format=source_format if source_format else "html",
568577
)
569-
if isinstance(element, List):
578+
if isinstance(element, list):
570579
for el in element:
571580
if last_modification_date:
572581
el.metadata.last_modified = last_modification_date
@@ -628,7 +637,7 @@ def document_to_element_list(
628637

629638

630639
def ocr_data_to_elements(
631-
ocr_data: List["LayoutElement"],
640+
ocr_data: list["LayoutElement"],
632641
image_size: tuple[int | float, int | float],
633642
common_metadata: Optional[ElementMetadata] = None,
634643
infer_list_items: bool = True,

unstructured/partition/csv.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from __future__ import annotations
2+
13
import csv
2-
from tempfile import SpooledTemporaryFile
3-
from typing import IO, BinaryIO, List, Optional, Union, cast
4+
from typing import IO, Any, Optional, cast
45

56
import pandas as pd
67
from lxml.html.soupparser import fromstring as soupparser_fromstring
@@ -29,18 +30,18 @@
2930
@add_chunking_strategy
3031
def partition_csv(
3132
filename: Optional[str] = None,
32-
file: Optional[Union[IO[bytes], SpooledTemporaryFile]] = None,
33+
file: Optional[IO[bytes]] = None,
3334
metadata_filename: Optional[str] = None,
3435
metadata_last_modified: Optional[str] = None,
3536
include_header: bool = False,
3637
include_metadata: bool = True,
3738
infer_table_structure: bool = True,
38-
languages: Optional[List[str]] = ["auto"],
39+
languages: Optional[list[str]] = ["auto"],
3940
# NOTE (jennings) partition_csv generates a single TableElement
4041
# so detect_language_per_element is not included as a param
4142
date_from_file_object: bool = False,
42-
**kwargs,
43-
) -> List[Element]:
43+
**kwargs: Any,
44+
) -> list[Element]:
4445
"""Partitions Microsoft Excel Documents in .csv format into its document elements.
4546
4647
Parameters
@@ -84,14 +85,12 @@ def partition_csv(
8485
last_modification_date = (
8586
get_last_modified_date_from_file(file) if date_from_file_object else None
8687
)
87-
f = spooled_to_bytes_io_if_needed(
88-
cast(Union[BinaryIO, SpooledTemporaryFile], file),
89-
)
88+
f = spooled_to_bytes_io_if_needed(file)
9089
delimiter = get_delimiter(file=f)
9190
table = pd.read_csv(f, header=header, sep=delimiter)
9291

9392
html_text = table.to_html(index=False, header=include_header, na_rep="")
94-
text = soupparser_fromstring(html_text).text_content()
93+
text = cast(str, soupparser_fromstring(html_text).text_content())
9594

9695
if include_metadata:
9796
metadata = ElementMetadata(

0 commit comments

Comments
 (0)