1919from concurrent .futures import ThreadPoolExecutor
2020from functools import partial
2121import pypdf
22- from pypdf import PdfReader , PdfWriter
22+ from pypdf import PdfReader , PdfWriter , PageObject
2323import psutil
2424import requests
2525import backoff
26- from typing import Optional , Mapping
26+ from typing import Any , Dict , IO , List , Mapping , Optional , Tuple
2727from fastapi import (
2828 status ,
2929 FastAPI ,
4040import secrets
4141
4242# Unstructured Imports
43+ from unstructured .documents .elements import Element
4344from unstructured .partition .auto import partition
4445from unstructured .staging .base import (
4546 convert_to_isd ,
5354app = FastAPI ()
5455router = APIRouter ()
5556
56-
57- def is_expected_response_type (media_type , response_type ):
58- if media_type == "application/json" and response_type not in [dict , list ]:
59- return True
60- elif media_type == "text/csv" and response_type != str :
61- return True
62- else :
63- return False
64-
65-
66- logger = logging .getLogger ("unstructured_api" )
67-
57+ IS_CHIPPER_PROCESSING = False
6858
6959DEFAULT_MIMETYPES = (
7060 "application/pdf,application/msword,image/jpeg,image/png,text/markdown,"
@@ -90,7 +80,21 @@ def is_expected_response_type(media_type, response_type):
9080 os .environ ["UNSTRUCTURED_ALLOWED_MIMETYPES" ] = DEFAULT_MIMETYPES
9181
9282
93- def get_pdf_splits (pdf_pages , split_size = 1 ):
def is_expected_response_type(media_type: str, response_type: type) -> bool:
    """
    Report whether a response's Python type conflicts with the requested media type.

    Note the inverted sense relative to the name: the result is True when the
    response type does NOT fit the media type, and the caller treats True as a
    conflict.

    media_type: the requested output media type, e.g. "application/json" or "text/csv".
    response_type: the class of the response object (callers pass type(response)),
        not the response value itself.

    Returns True when "application/json" is paired with anything other than dict or
    list, or when "text/csv" is paired with anything other than str. Every other
    media type yields False.
    """
    if media_type == "application/json":
        return response_type not in (dict, list)
    if media_type == "text/csv":
        return response_type != str
    return False
92+
93+
94+ logger = logging .getLogger ("unstructured_api" )
95+
96+
97+ def get_pdf_splits (pdf_pages : List [PageObject ], split_size : int = 1 ):
9498 """
9599 Given a pdf (PdfReader) with n pages, split it into pdfs each with split_size # of pages
96100 Return the files with their page offset in the form [( BytesIO, int)]
@@ -113,8 +117,8 @@ def get_pdf_splits(pdf_pages, split_size=1):
113117
114118
115119# Do not retry with these status codes
def is_non_retryable(e: Exception) -> bool:
    """
    Giveup predicate for the retry decorator: True when the error is a client error.

    e: the exception raised by the wrapped call. It is expected to carry an integer
        status_code attribute; an exception without one is treated as retryable
        instead of raising AttributeError from inside the retry handler.

    Returns True for status codes in the 4xx range (400 <= code < 500).
    """
    status_code = getattr(e, "status_code", None)
    if not isinstance(status_code, int):
        return False
    return 400 <= status_code < 500
118122
119123
120124@backoff .on_exception (
@@ -124,7 +128,14 @@ def is_non_retryable(e):
124128 giveup = is_non_retryable ,
125129 logger = logger ,
126130)
127- def call_api (request_url , api_key , filename , file , content_type , ** partition_kwargs ):
131+ def call_api (
132+ request_url : str ,
133+ api_key : str ,
134+ filename : str ,
135+ file : IO [bytes ],
136+ content_type : str ,
137+ ** partition_kwargs : Dict [str , Any ],
138+ ):
128139 """
129140 Call the api with the given request_url.
130141 """
@@ -144,7 +155,13 @@ def call_api(request_url, api_key, filename, file, content_type, **partition_kwa
144155 return response .text
145156
146157
147- def partition_file_via_api (file_tuple , request , filename , content_type , ** partition_kwargs ):
158+ def partition_file_via_api (
159+ file_tuple : Tuple [Any , Any ],
160+ request : Request ,
161+ filename : str ,
162+ content_type : str ,
163+ ** partition_kwargs : Dict [str , Any ],
164+ ):
148165 """
149166 Send the given file to be partitioned remotely with retry logic,
150167 where the remote url is set by env var.
@@ -163,7 +180,7 @@ def partition_file_via_api(file_tuple, request, filename, content_type, **partit
163180
164181 api_key = request .headers .get ("unstructured-api-key" )
165182
166- result = call_api (request_url , api_key , filename , file , content_type , ** partition_kwargs )
183+ result = call_api (request_url , api_key , filename , file , content_type , ** partition_kwargs ) # type: ignore
167184 elements = elements_from_json (text = result )
168185
169186 # We need to account for the original page numbers
@@ -176,8 +193,14 @@ def partition_file_via_api(file_tuple, request, filename, content_type, **partit
176193
177194
178195def partition_pdf_splits (
179- request , pdf_pages , file , metadata_filename , content_type , coordinates , ** partition_kwargs
180- ):
196+ request : Request ,
197+ pdf_pages : List [PageObject ],
198+ file : IO [bytes ],
199+ metadata_filename : str ,
200+ content_type : str ,
201+ coordinates : bool ,
202+ ** partition_kwargs : Dict [str , Any ],
203+ ) -> List [Element ]:
181204 """
182205 Split a pdf into chunks and process in parallel with more api calls, or partition
183206 locally if the chunk is small enough. As soon as any remote call fails, bubble up
@@ -204,7 +227,7 @@ def partition_pdf_splits(
204227 ** partition_kwargs ,
205228 )
206229
207- results = []
230+ results : List [ Element ] = []
208231 page_iterator = get_pdf_splits (pdf_pages , split_size = pages_per_pdf )
209232
210233 partition_func = partial (
@@ -224,9 +247,6 @@ def partition_pdf_splits(
224247 return results
225248
226249
227- IS_CHIPPER_PROCESSING = False
228-
229-
230250class ChipperMemoryProtection :
231251 """
232252 Chipper calls are expensive, and right now we can only do one call at a time.
@@ -251,34 +271,34 @@ def __exit__(self, exc_type, exc_value, exc_tb):
251271
252272
253273def pipeline_api (
254- file ,
255- request = None ,
256- filename = "" ,
257- file_content_type = None ,
258- response_type = "application/json" ,
259- m_coordinates = [],
260- m_encoding = [],
261- m_hi_res_model_name = [],
262- m_include_page_breaks = [],
263- m_ocr_languages = None ,
264- m_pdf_infer_table_structure = [],
265- m_skip_infer_table_types = [],
266- m_strategy = [],
267- m_xml_keep_tags = [],
268- languages = None ,
269- m_chunking_strategy = [],
270- m_multipage_sections = [],
271- m_combine_under_n_chars = [],
272- m_new_after_n_chars = [],
273- m_max_characters = [],
274+ file : Optional [ IO [ bytes ]] ,
275+ request : Request ,
276+ filename : Union [ str , None ] = "" ,
277+ file_content_type : Union [ str , None ] = None ,
278+ response_type : str = "application/json" ,
279+ m_coordinates : List [ str ] = [],
280+ m_encoding : List [ str ] = [],
281+ m_hi_res_model_name : List [ str ] = [],
282+ m_include_page_breaks : List [ str ] = [],
283+ m_ocr_languages : Union [ List [ str ], None ] = None ,
284+ m_pdf_infer_table_structure : List [ str ] = [],
285+ m_skip_infer_table_types : List [ str ] = [],
286+ m_strategy : List [ str ] = [],
287+ m_xml_keep_tags : List [ str ] = [],
288+ languages : Union [ List [ str ], None ] = None ,
289+ m_chunking_strategy : List [ str ] = [],
290+ m_multipage_sections : List [ str ] = [],
291+ m_combine_under_n_chars : List [ str ] = [],
292+ m_new_after_n_chars : List [ str ] = [],
293+ m_max_characters : List [ str ] = [],
274294):
275- if filename .endswith (".msg" ):
295+ if filename and filename .endswith (".msg" ):
276296 # Note(yuming): convert file type for msg files
277297 # since fast api might sent the wrong one.
278298 file_content_type = "application/x-ole-storage"
279299
280300 # We don't want to keep logging the same params for every parallel call
281- origin_ip = request .headers .get ("X-Forwarded-For" ) or request .client .host
301+ origin_ip = request .headers .get ("X-Forwarded-For" ) or request .client .host # type: ignore
282302 is_internal_request = origin_ip .startswith ("10." )
283303
284304 if not is_internal_request :
@@ -313,11 +333,10 @@ def pipeline_api(
313333
314334 _check_free_memory ()
315335
316- if file_content_type == "application/pdf" :
317- pdf = _check_pdf (file )
336+ pdf = _check_pdf (file ) if file and file_content_type == "application/pdf" else None
318337
319338 show_coordinates_str = (m_coordinates [0 ] if len (m_coordinates ) else "false" ).lower ()
320- show_coordinates = show_coordinates_str == "true"
339+ show_coordinates : bool = show_coordinates_str == "true"
321340
322341 hi_res_model_name = _validate_hi_res_model_name (m_hi_res_model_name , show_coordinates )
323342 strategy = _validate_strategy (m_strategy )
@@ -394,7 +413,7 @@ def pipeline_api(
394413 )
395414 )
396415
397- partition_kwargs = {
416+ partition_kwargs : Dict [ str , Any ] = {
398417 "file" : file ,
399418 "metadata_filename" : filename ,
400419 "content_type" : file_content_type ,
@@ -413,8 +432,9 @@ def pipeline_api(
413432 "new_after_n_chars" : new_after_n_chars ,
414433 "max_characters" : max_characters ,
415434 }
435+ elements : List [Element ]
416436
417- if file_content_type == "application/pdf" and pdf_parallel_mode_enabled :
437+ if file_content_type == "application/pdf" and pdf_parallel_mode_enabled and pdf :
418438 # Be careful of naming differences in api params vs partition params!
419439 # These kwargs are going back into the api, not into partition
420440 # They need to be switched back in partition_pdf_splits
@@ -516,25 +536,25 @@ def _check_free_memory():
516536 )
517537
518538
def _check_pdf(file: Union[str, IO[Any]]) -> PdfReader:
    """
    Open the given file as a PDF and return the reader, or fail with an HTTP error.

    file: a path or a file-like object, handed to PdfReader as its stream.

    Raises HTTPException with status 400 when the file is encrypted, and with
    status 422 when pypdf cannot read it as a PDF.
    """
    try:
        reader = PdfReader(stream=file)  # StrByteType can be str or IO[Any]

        # Touching the metadata makes pypdf raise if the file is encrypted
        reader.metadata
    except pypdf.errors.FileNotDecryptedError:  # type: ignore
        raise HTTPException(
            status_code=400,
            detail="File is encrypted. Please decrypt it with password.",
        )
    except pypdf.errors.PdfReadError:  # type: ignore
        raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF")
    else:
        return reader
534554
535555
536- def _validate_strategy (m_strategy ) :
537- strategy = (m_strategy [0 ] if len (m_strategy ) else "auto" ).lower ()
556+ def _validate_strategy (m_strategy : List [ str ]) -> str :
557+ strategy : str = (m_strategy [0 ] if len (m_strategy ) else "auto" ).lower ()
538558 strategies = ["fast" , "hi_res" , "auto" , "ocr_only" ]
539559 if strategy not in strategies :
540560 raise HTTPException (
@@ -543,7 +563,9 @@ def _validate_strategy(m_strategy):
543563 return strategy
544564
545565
546- def _validate_hi_res_model_name (m_hi_res_model_name , show_coordinates ):
566+ def _validate_hi_res_model_name (
567+ m_hi_res_model_name : List [str ], show_coordinates : bool
568+ ) -> Union [str , None ]:
547569 hi_res_model_name = m_hi_res_model_name [0 ] if len (m_hi_res_model_name ) else None
548570
549571 # Make sure chipper aliases to the latest model
@@ -558,7 +580,7 @@ def _validate_hi_res_model_name(m_hi_res_model_name, show_coordinates):
558580 return hi_res_model_name
559581
560582
561- def _validate_chunking_strategy (m_chunking_strategy ) :
583+ def _validate_chunking_strategy (m_chunking_strategy : List [ str ]) -> Union [ str , None ] :
562584 chunking_strategy = m_chunking_strategy [0 ].lower () if len (m_chunking_strategy ) else None
563585 chunk_strategies = ["by_title" ]
564586 if chunking_strategy and (chunking_strategy not in chunk_strategies ):
@@ -569,7 +591,7 @@ def _validate_chunking_strategy(m_chunking_strategy):
569591 return chunking_strategy
570592
571593
572- def _set_pdf_infer_table_structure (m_pdf_infer_table_structure , strategy ):
594+ def _set_pdf_infer_table_structure (m_pdf_infer_table_structure : List [ str ] , strategy : str ):
573595 pdf_infer_table_structure = (
574596 m_pdf_infer_table_structure [0 ] if len (m_pdf_infer_table_structure ) else "false"
575597 ).lower ()
@@ -580,7 +602,7 @@ def _set_pdf_infer_table_structure(m_pdf_infer_table_structure, strategy):
580602 return pdf_infer_table_structure
581603
582604
583- def get_validated_mimetype (file ):
605+ def get_validated_mimetype (file : UploadFile ):
584606 """
585607 Return a file's mimetype, either via the file.content_type or the mimetypes lib if that's too
586608 generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and
@@ -591,7 +613,7 @@ def get_validated_mimetype(file):
591613 content_type = mimetypes .guess_type (str (file .filename ))[0 ]
592614
593615 # Some filetypes missing for this library, just hardcode them for now
594- if not content_type :
616+ if not content_type and file . filename :
595617 if file .filename .endswith (".md" ):
596618 content_type = "text/markdown"
597619 elif file .filename .endswith (".msg" ):
@@ -613,7 +635,7 @@ def get_validated_mimetype(file):
613635class MultipartMixedResponse (StreamingResponse ):
614636 CRLF = b"\r \n "
615637
616- def __init__ (self , * args , content_type : str = None , ** kwargs ):
638+ def __init__ (self , * args : Any , content_type : Union [ str , None ] = None , ** kwargs : Dict [ str , Any ] ):
617639 super ().__init__ (* args , ** kwargs )
618640 self .content_type = content_type
619641
@@ -627,15 +649,18 @@ def init_headers(self, headers: Optional[Mapping[str, str]] = None) -> None:
627649 def boundary (self ):
628650 return b"--" + self .boundary_value .encode ()
629651
630- def _build_part_headers (self , headers : dict ) -> bytes :
652+ def _build_part_headers (self , headers : Dict [ str , str ] ) -> bytes :
631653 header_bytes = b""
632654 for header , value in headers .items ():
633655 header_bytes += f"{ header } : { value } " .encode () + self .CRLF
634656 return header_bytes
635657
636658 def build_part (self , chunk : bytes ) -> bytes :
637659 part = self .boundary + self .CRLF
638- part_headers = {"Content-Length" : len (chunk ), "Content-Transfer-Encoding" : "base64" }
660+ part_headers : Dict [str , Any ] = {
661+ "Content-Length" : len (chunk ),
662+ "Content-Transfer-Encoding" : "base64" ,
663+ }
639664 if self .content_type is not None :
640665 part_headers ["Content-Type" ] = self .content_type
641666 part += self ._build_part_headers (part_headers )
@@ -661,8 +686,10 @@ async def stream_response(self, send: Send) -> None:
661686 await send ({"type" : "http.response.body" , "body" : b"" , "more_body" : False })
662687
663688
664- def ungz_file (file : UploadFile , gz_uncompressed_content_type = None ) -> UploadFile :
665- def return_content_type (filename ):
689+ def ungz_file (
690+ file : UploadFile , gz_uncompressed_content_type : Union [str , None ] = None
691+ ) -> UploadFile :
692+ def return_content_type (filename : str ):
666693 if gz_uncompressed_content_type :
667694 return gz_uncompressed_content_type
668695 else :
@@ -740,7 +767,7 @@ def pipeline_1(
740767 status_code = status .HTTP_406_NOT_ACCEPTABLE ,
741768 )
742769
743- def response_generator (is_multipart ):
770+ def response_generator (is_multipart : bool ):
744771 for file in files :
745772 file_content_type = get_validated_mimetype (file )
746773
@@ -768,8 +795,7 @@ def response_generator(is_multipart):
768795 m_new_after_n_chars = new_after_n_chars ,
769796 m_max_characters = max_characters ,
770797 )
771-
772- if is_expected_response_type (media_type , type (response )):
798+ if is_expected_response_type (media_type , type (response )): # type: ignore
773799 raise HTTPException (
774800 detail = (
775801 f"Conflict in media type { media_type } "
@@ -792,7 +818,7 @@ def response_generator(is_multipart):
792818 status_code = status .HTTP_406_NOT_ACCEPTABLE ,
793819 )
794820
795- def join_responses (responses ):
821+ def join_responses (responses : List [ Any ] ):
796822 if media_type != "text/csv" :
797823 return responses
798824 data = pd .read_csv (io .BytesIO (responses [0 ].body ))
0 commit comments