
Commit 2755807

WIP
1 parent 4ffd8bc commit 2755807


prepline_general/api/general.py

Lines changed: 99 additions & 73 deletions
@@ -19,11 +19,11 @@
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 import pypdf
-from pypdf import PdfReader, PdfWriter
+from pypdf import PdfReader, PdfWriter, PageObject
 import psutil
 import requests
 import backoff
-from typing import Optional, Mapping
+from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, Union
 from fastapi import (
     status,
     FastAPI,
@@ -40,6 +40,7 @@
 import secrets

 # Unstructured Imports
+from unstructured.documents.elements import Element
 from unstructured.partition.auto import partition
 from unstructured.staging.base import (
     convert_to_isd,
@@ -53,18 +54,7 @@
 app = FastAPI()
 router = APIRouter()

-
-def is_expected_response_type(media_type, response_type):
-    if media_type == "application/json" and response_type not in [dict, list]:
-        return True
-    elif media_type == "text/csv" and response_type != str:
-        return True
-    else:
-        return False
-
-
-logger = logging.getLogger("unstructured_api")
-
+IS_CHIPPER_PROCESSING = False

 DEFAULT_MIMETYPES = (
     "application/pdf,application/msword,image/jpeg,image/png,text/markdown,"
@@ -90,7 +80,21 @@ def is_expected_response_type(media_type, response_type):
     os.environ["UNSTRUCTURED_ALLOWED_MIMETYPES"] = DEFAULT_MIMETYPES


-def get_pdf_splits(pdf_pages, split_size=1):
+def is_expected_response_type(
+    media_type: str, response_type: Union[str, Dict[Any, Any], List[Any]]
+) -> bool:
+    if media_type == "application/json" and response_type not in [dict, list]:  # type: ignore
+        return True
+    elif media_type == "text/csv" and response_type != str:
+        return True
+    else:
+        return False
+
+
+logger = logging.getLogger("unstructured_api")
+
+
+def get_pdf_splits(pdf_pages: List[PageObject], split_size: int = 1):
     """
     Given a pdf (PdfReader) with n pages, split it into pdfs each with split_size # of pages
     Return the files with their page offset in the form [( BytesIO, int)]
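Note: the body of get_pdf_splits is outside this hunk. As a rough sketch only, the (BytesIO, page offset) pairs described in the docstring can be produced with pypdf's PdfWriter along these lines; the helper name and details are illustrative, not the committed implementation.

from io import BytesIO
from typing import List, Tuple

from pypdf import PageObject, PdfWriter


def split_pages_sketch(pdf_pages: List[PageObject], split_size: int = 1) -> List[Tuple[BytesIO, int]]:
    # Chunk the pages, write each chunk to an in-memory pdf, and remember
    # the offset of the chunk's first page in the original document.
    splits: List[Tuple[BytesIO, int]] = []
    for offset in range(0, len(pdf_pages), split_size):
        writer = PdfWriter()
        for page in pdf_pages[offset:offset + split_size]:
            writer.add_page(page)
        chunk = BytesIO()
        writer.write(chunk)
        chunk.seek(0)
        splits.append((chunk, offset))
    return splits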
@@ -113,8 +117,8 @@ def get_pdf_splits(pdf_pages, split_size=1):


 # Do not retry with these status codes
-def is_non_retryable(e):
-    return 400 <= e.status_code < 500
+def is_non_retryable(e: Exception) -> bool:
+    return 400 <= e.status_code < 500  # type: ignore


 @backoff.on_exception(
@@ -124,7 +128,14 @@ def is_non_retryable(e):
     giveup=is_non_retryable,
     logger=logger,
 )
-def call_api(request_url, api_key, filename, file, content_type, **partition_kwargs):
+def call_api(
+    request_url: str,
+    api_key: str,
+    filename: str,
+    file: IO[bytes],
+    content_type: str,
+    **partition_kwargs: Dict[str, Any],
+):
     """
     Call the api with the given request_url.
     """
@@ -144,7 +155,13 @@ def call_api(request_url, api_key, filename, file, content_type, **partition_kwa
     return response.text


-def partition_file_via_api(file_tuple, request, filename, content_type, **partition_kwargs):
+def partition_file_via_api(
+    file_tuple: Tuple[Any, Any],
+    request: Request,
+    filename: str,
+    content_type: str,
+    **partition_kwargs: Dict[str, Any],
+):
     """
     Send the given file to be partitioned remotely with retry logic,
     where the remote url is set by env var.
@@ -163,7 +180,7 @@ def partition_file_via_api(file_tuple, request, filename, content_type, **partit

     api_key = request.headers.get("unstructured-api-key")

-    result = call_api(request_url, api_key, filename, file, content_type, **partition_kwargs)
+    result = call_api(request_url, api_key, filename, file, content_type, **partition_kwargs)  # type: ignore
     elements = elements_from_json(text=result)

     # We need to account for the original page numbers
@@ -176,8 +193,14 @@ def partition_file_via_api(file_tuple, request, filename, content_type, **partit


 def partition_pdf_splits(
-    request, pdf_pages, file, metadata_filename, content_type, coordinates, **partition_kwargs
-):
+    request: Request,
+    pdf_pages: List[PageObject],
+    file: IO[bytes],
+    metadata_filename: str,
+    content_type: str,
+    coordinates: bool,
+    **partition_kwargs: Dict[str, Any],
+) -> List[Element]:
     """
     Split a pdf into chunks and process in parallel with more api calls, or partition
     locally if the chunk is small enough. As soon as any remote call fails, bubble up
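As a hedged sketch of the fan-out this docstring describes (the helpers below are hypothetical placeholders, not the module's real call chain): split into chunks, map a partial over them with a thread pool, and flatten the results. executor.map keeps results in input order and re-raises the first worker exception when its result is consumed, which is how a failed remote call bubbles up.

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from typing import List, Tuple


def process_chunk(chunk: Tuple[BytesIO, int], content_type: str) -> List[str]:
    # Placeholder for a remote partition call on one chunk of pages.
    _data, page_offset = chunk
    return [f"element from page offset {page_offset}"]


def partition_chunks_in_parallel(chunks: List[Tuple[BytesIO, int]], content_type: str) -> List[str]:
    results: List[str] = []
    worker = partial(process_chunk, content_type=content_type)
    with ThreadPoolExecutor(max_workers=4) as executor:
        for chunk_elements in executor.map(worker, chunks):
            results.extend(chunk_elements)
    return results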
@@ -204,7 +227,7 @@ def partition_pdf_splits(
             **partition_kwargs,
         )

-    results = []
+    results: List[Element] = []
     page_iterator = get_pdf_splits(pdf_pages, split_size=pages_per_pdf)

     partition_func = partial(
@@ -224,9 +247,6 @@
     return results


-IS_CHIPPER_PROCESSING = False
-
-
 class ChipperMemoryProtection:
     """
     Chipper calls are expensive, and right now we can only do one call at a time.
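A minimal sketch of the single-flight guard that this class and the IS_CHIPPER_PROCESSING flag (moved to the top of the module in this commit) implement together; the flag name, class name, and RuntimeError here are illustrative stand-ins for whatever busy response the API actually returns.

IS_BUSY = False


class SingleFlightGuard:
    """Allow only one expensive call at a time; reject concurrent callers."""

    def __enter__(self):
        global IS_BUSY
        if IS_BUSY:
            # In a web handler this would typically map to an HTTP 503 response.
            raise RuntimeError("Server is busy processing another request.")
        IS_BUSY = True
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        global IS_BUSY
        IS_BUSY = False


def expensive_call() -> str:
    with SingleFlightGuard():
        return "done"

Note the check-and-set on the flag is not atomic; a threading.Lock acquired with blocking=False would be the stricter variant of the same idea.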
@@ -251,34 +271,34 @@ def __exit__(self, exc_type, exc_value, exc_tb):


 def pipeline_api(
-    file,
-    request=None,
-    filename="",
-    file_content_type=None,
-    response_type="application/json",
-    m_coordinates=[],
-    m_encoding=[],
-    m_hi_res_model_name=[],
-    m_include_page_breaks=[],
-    m_ocr_languages=None,
-    m_pdf_infer_table_structure=[],
-    m_skip_infer_table_types=[],
-    m_strategy=[],
-    m_xml_keep_tags=[],
-    languages=None,
-    m_chunking_strategy=[],
-    m_multipage_sections=[],
-    m_combine_under_n_chars=[],
-    m_new_after_n_chars=[],
-    m_max_characters=[],
+    file: Optional[IO[bytes]],
+    request: Request,
+    filename: Union[str, None] = "",
+    file_content_type: Union[str, None] = None,
+    response_type: str = "application/json",
+    m_coordinates: List[str] = [],
+    m_encoding: List[str] = [],
+    m_hi_res_model_name: List[str] = [],
+    m_include_page_breaks: List[str] = [],
+    m_ocr_languages: Union[List[str], None] = None,
+    m_pdf_infer_table_structure: List[str] = [],
+    m_skip_infer_table_types: List[str] = [],
+    m_strategy: List[str] = [],
+    m_xml_keep_tags: List[str] = [],
+    languages: Union[List[str], None] = None,
+    m_chunking_strategy: List[str] = [],
+    m_multipage_sections: List[str] = [],
+    m_combine_under_n_chars: List[str] = [],
+    m_new_after_n_chars: List[str] = [],
+    m_max_characters: List[str] = [],
 ):
-    if filename.endswith(".msg"):
+    if filename and filename.endswith(".msg"):
         # Note(yuming): convert file type for msg files
         # since fast api might sent the wrong one.
         file_content_type = "application/x-ole-storage"

     # We don't want to keep logging the same params for every parallel call
-    origin_ip = request.headers.get("X-Forwarded-For") or request.client.host
+    origin_ip = request.headers.get("X-Forwarded-For") or request.client.host  # type: ignore
     is_internal_request = origin_ip.startswith("10.")

     if not is_internal_request:
@@ -313,11 +333,10 @@ def pipeline_api(

     _check_free_memory()

-    if file_content_type == "application/pdf":
-        pdf = _check_pdf(file)
+    pdf = _check_pdf(file) if file and file_content_type == "application/pdf" else None

     show_coordinates_str = (m_coordinates[0] if len(m_coordinates) else "false").lower()
-    show_coordinates = show_coordinates_str == "true"
+    show_coordinates: bool = show_coordinates_str == "true"

     hi_res_model_name = _validate_hi_res_model_name(m_hi_res_model_name, show_coordinates)
     strategy = _validate_strategy(m_strategy)
@@ -394,7 +413,7 @@
            )
        )

-    partition_kwargs = {
+    partition_kwargs: Dict[str, Any] = {
         "file": file,
         "metadata_filename": filename,
         "content_type": file_content_type,
@@ -413,8 +432,9 @@
         "new_after_n_chars": new_after_n_chars,
         "max_characters": max_characters,
     }
+    elements: List[Element]

-    if file_content_type == "application/pdf" and pdf_parallel_mode_enabled:
+    if file_content_type == "application/pdf" and pdf_parallel_mode_enabled and pdf:
         # Be careful of naming differences in api params vs partition params!
         # These kwargs are going back into the api, not into partition
         # They need to be switched back in partition_pdf_splits
@@ -516,25 +536,25 @@ def _check_free_memory():
         )


-def _check_pdf(file):
+def _check_pdf(file: Union[str, IO[Any]]) -> PdfReader:
     """Check if the PDF file is encrypted, otherwise assume it is not a valid PDF."""
     try:
-        pdf = PdfReader(file)
+        pdf = PdfReader(stream=file)  # StrByteType can be str or IO[Any]

         # This will raise if the file is encrypted
         pdf.metadata
         return pdf
-    except pypdf.errors.FileNotDecryptedError:
+    except pypdf.errors.FileNotDecryptedError:  # type: ignore
         raise HTTPException(
             status_code=400,
             detail="File is encrypted. Please decrypt it with password.",
         )
-    except pypdf.errors.PdfReadError:
+    except pypdf.errors.PdfReadError:  # type: ignore
         raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF")


-def _validate_strategy(m_strategy):
-    strategy = (m_strategy[0] if len(m_strategy) else "auto").lower()
+def _validate_strategy(m_strategy: List[str]) -> str:
+    strategy: str = (m_strategy[0] if len(m_strategy) else "auto").lower()
     strategies = ["fast", "hi_res", "auto", "ocr_only"]
     if strategy not in strategies:
         raise HTTPException(
@@ -543,7 +563,9 @@ def _validate_strategy(m_strategy):
     return strategy


-def _validate_hi_res_model_name(m_hi_res_model_name, show_coordinates):
+def _validate_hi_res_model_name(
+    m_hi_res_model_name: List[str], show_coordinates: bool
+) -> Union[str, None]:
     hi_res_model_name = m_hi_res_model_name[0] if len(m_hi_res_model_name) else None

     # Make sure chipper aliases to the latest model
@@ -558,7 +580,7 @@ def _validate_hi_res_model_name(m_hi_res_model_name, show_coordinates):
     return hi_res_model_name


-def _validate_chunking_strategy(m_chunking_strategy):
+def _validate_chunking_strategy(m_chunking_strategy: List[str]) -> Union[str, None]:
     chunking_strategy = m_chunking_strategy[0].lower() if len(m_chunking_strategy) else None
     chunk_strategies = ["by_title"]
     if chunking_strategy and (chunking_strategy not in chunk_strategies):
@@ -569,7 +591,7 @@ def _validate_chunking_strategy(m_chunking_strategy):
     return chunking_strategy


-def _set_pdf_infer_table_structure(m_pdf_infer_table_structure, strategy):
+def _set_pdf_infer_table_structure(m_pdf_infer_table_structure: List[str], strategy: str):
     pdf_infer_table_structure = (
         m_pdf_infer_table_structure[0] if len(m_pdf_infer_table_structure) else "false"
     ).lower()
@@ -580,7 +602,7 @@ def _set_pdf_infer_table_structure(m_pdf_infer_table_structure, strategy):
     return pdf_infer_table_structure


-def get_validated_mimetype(file):
+def get_validated_mimetype(file: UploadFile):
     """
     Return a file's mimetype, either via the file.content_type or the mimetypes lib if that's too
     generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and
@@ -591,7 +613,7 @@ def get_validated_mimetype(file):
         content_type = mimetypes.guess_type(str(file.filename))[0]

         # Some filetypes missing for this library, just hardcode them for now
-        if not content_type:
+        if not content_type and file.filename:
             if file.filename.endswith(".md"):
                 content_type = "text/markdown"
             elif file.filename.endswith(".msg"):
@@ -613,7 +635,7 @@
 class MultipartMixedResponse(StreamingResponse):
     CRLF = b"\r\n"

-    def __init__(self, *args, content_type: str = None, **kwargs):
+    def __init__(self, *args: Any, content_type: Union[str, None] = None, **kwargs: Dict[str, Any]):
         super().__init__(*args, **kwargs)
         self.content_type = content_type

@@ -627,15 +649,18 @@ def init_headers(self, headers: Optional[Mapping[str, str]] = None) -> None:
     def boundary(self):
         return b"--" + self.boundary_value.encode()

-    def _build_part_headers(self, headers: dict) -> bytes:
+    def _build_part_headers(self, headers: Dict[str, str]) -> bytes:
         header_bytes = b""
         for header, value in headers.items():
             header_bytes += f"{header}: {value}".encode() + self.CRLF
         return header_bytes

     def build_part(self, chunk: bytes) -> bytes:
         part = self.boundary + self.CRLF
-        part_headers = {"Content-Length": len(chunk), "Content-Transfer-Encoding": "base64"}
+        part_headers: Dict[str, Any] = {
+            "Content-Length": len(chunk),
+            "Content-Transfer-Encoding": "base64",
+        }
         if self.content_type is not None:
             part_headers["Content-Type"] = self.content_type
         part += self._build_part_headers(part_headers)
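For reference, a self-contained sketch of the multipart/mixed framing that build_part assembles: boundary line, part headers, a blank line, then the payload. The boundary value, base64 handling, and Content-Length below are illustrative and may differ in detail from the class above.

import base64
from typing import Any, Dict

CRLF = b"\r\n"
BOUNDARY = b"--example-boundary"


def build_part_headers(headers: Dict[str, Any]) -> bytes:
    header_bytes = b""
    for header, value in headers.items():
        header_bytes += f"{header}: {value}".encode() + CRLF
    return header_bytes


def build_part(chunk: bytes, content_type: str = "application/json") -> bytes:
    payload = base64.b64encode(chunk)
    part_headers: Dict[str, Any] = {
        "Content-Length": len(payload),
        "Content-Transfer-Encoding": "base64",
        "Content-Type": content_type,
    }
    part = BOUNDARY + CRLF
    part += build_part_headers(part_headers)
    part += CRLF  # blank line separates the headers from the body
    part += payload + CRLF
    return part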
@@ -661,8 +686,10 @@ async def stream_response(self, send: Send) -> None:
         await send({"type": "http.response.body", "body": b"", "more_body": False})


-def ungz_file(file: UploadFile, gz_uncompressed_content_type=None) -> UploadFile:
-    def return_content_type(filename):
+def ungz_file(
+    file: UploadFile, gz_uncompressed_content_type: Union[str, None] = None
+) -> UploadFile:
+    def return_content_type(filename: str):
         if gz_uncompressed_content_type:
             return gz_uncompressed_content_type
         else:
@@ -740,7 +767,7 @@ def pipeline_1(
            status_code=status.HTTP_406_NOT_ACCEPTABLE,
        )

-    def response_generator(is_multipart):
+    def response_generator(is_multipart: bool):
        for file in files:
            file_content_type = get_validated_mimetype(file)

@@ -768,8 +795,7 @@ def response_generator(is_multipart):
                m_new_after_n_chars=new_after_n_chars,
                m_max_characters=max_characters,
            )
-
-            if is_expected_response_type(media_type, type(response)):
+            if is_expected_response_type(media_type, type(response)):  # type: ignore
                raise HTTPException(
                    detail=(
                        f"Conflict in media type {media_type}"
@@ -792,7 +818,7 @@ def response_generator(is_multipart):
                    status_code=status.HTTP_406_NOT_ACCEPTABLE,
                )

-    def join_responses(responses):
+    def join_responses(responses: List[Any]):
        if media_type != "text/csv":
            return responses
        data = pd.read_csv(io.BytesIO(responses[0].body))
