44"""
55
66import json
7- import subprocess
87import time
98
109import gradio as gr
1312
1413import os
1514import PyPDF2
15+ import nest_asyncio
16+ from llama_parse import LlamaParse
1617
1718from fastchat .constants import (
1819 TEXT_MODERATION_MSG ,
@@ -246,22 +247,37 @@ def clear_history(request: gr.Request):
246247 + ["" ]
247248 )
248249
def is_pdf(file_path):
    """Return True if the file at *file_path* begins with the PDF signature.

    The check looks only at the first five bytes of the file and compares
    them with ``%PDF-``; the file extension is not consulted.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when the file starts with ``b"%PDF-"``; False otherwise,
        including when the file cannot be opened or read.
    """
    try:
        with open(file_path, "rb") as file:
            header = file.read(5)  # Length of the "%PDF-" signature
    except (OSError, TypeError, ValueError) as e:
        # Best-effort sniffing: an unreadable or invalid path is "not a PDF".
        logger.error(f"Failed to read file header: {e}")
        return False
    return header == b"%PDF-"
258+
def is_image(file_path):
    """Return True if the file at *file_path* starts with a known image signature.

    Recognised formats are JPEG, PNG, GIF (87a/89a), BMP, ICO, TIFF (both
    byte orders) and WebP. Only the leading bytes of the file are inspected;
    the file extension is not consulted.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when the header matches one of the signatures above; False
        otherwise, including when the file cannot be opened or read.
    """
    # Signatures that identify the format on their own.
    magic_numbers = (
        b"\xff\xd8\xff",  # JPEG
        b"\x89PNG\r\n\x1a\n",  # PNG
        b"GIF87a",  # GIF
        b"GIF89a",  # GIF
        b"BM",  # BMP
        b"\x00\x00\x01\x00",  # ICO
        b"\x49\x49\x2a\x00",  # TIFF, little-endian
        b"\x4d\x4d\x00\x2a",  # TIFF, big-endian
    )
    try:
        with open(file_path, "rb") as file:
            # 12 bytes cover the longest check: "RIFF" + 4-byte size + "WEBP".
            header = file.read(12)
    except (OSError, TypeError, ValueError) as e:
        # Best-effort sniffing: an unreadable or invalid path is "not an image".
        logger.error(f"Error reading file: {e}")
        return False

    if header.startswith(magic_numbers):
        return True
    # "RIFF" alone also matches WAV/AVI; WebP additionally carries the
    # form type "WEBP" at bytes 8-11.
    return header[:4] == b"RIFF" and header[8:12] == b"WEBP"
265281
266282nest_asyncio .apply () # Ensure compatibility with async environments
267283
@@ -278,14 +294,7 @@ def pdf_parse(pdf_path):
278294 language = "en" # Set language (default is English)
279295 )
280296
281- # Prepare the output directory and file name
282- output_dir = "outputs"
283- os .makedirs (output_dir , exist_ok = True )
284-
285297 pdf_name = os .path .splitext (os .path .basename (pdf_path ))[0 ]
286- markdown_file_path = os .path .join (output_dir , f"{ pdf_name } .md" )
287-
288- # Load and parse the PDF
289298 extra_info = {"file_name" : pdf_name }
290299
291300 with open (pdf_path , "rb" ) as pdf_file :
@@ -326,10 +335,14 @@ def add_text(
326335 else :
327336 text = chat_input
328337 files = []
329-
338+
330339 images = []
331340
332- file_extension = os .path .splitext (files [0 ])[1 ].lower ()
341+ # currently support up to one pdf or one image
342+ # if is_pdf(files[0]):
343+ # pdfs = files
344+ if is_image (files [0 ]):
345+ images = files
333346
334347 ip = get_ip (request )
335348 logger .info (f"add_text (anony). ip: { ip } . len: { len (text )} " )
@@ -340,7 +353,7 @@ def add_text(
340353 if states [0 ] is None :
341354 assert states [1 ] is None
342355
343- if len (files ) > 0 and file_extension != ".pdf" :
356+ if len (files ) > 0 and is_image ( files [ 0 ]) :
344357 model_left , model_right = get_battle_pair (
345358 context .all_vision_models ,
346359 VISION_BATTLE_TARGETS ,
@@ -423,7 +436,7 @@ def add_text(
423436 + ["" ]
424437 )
425438
426- if file_extension != ".pdf" :
439+ if is_image ( files [ 0 ]) :
427440 text = text [:BLIND_MODE_INPUT_CHAR_LEN_LIMIT ] # Hard cut-off
428441 for i in range (num_sides ):
429442 post_processed_text = _prepare_text_with_image (
@@ -438,9 +451,9 @@ def add_text(
438451 if "deluxe" in states [i ].model_name :
439452 hint_msg = SLOW_MODEL_MSG
440453
441- if file_extension == ".pdf" :
454+ if is_pdf ( files [ 0 ]) :
442455 document_text = pdf_parse (files [0 ])
443- post_processed_text = f"""
456+ prompt_text = f"""
444457 The following is the content of a document:
445458
446459 { document_text }
@@ -449,8 +462,7 @@ def add_text(
449462
450463 { text }
451464 """
452-
453- post_processed_text = wrap_query_context (text , post_processed_text )
465+ post_processed_text = wrap_query_context (text , prompt_text )
454466
455467 # text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
456468 for i in range (num_sides ):
0 commit comments