44"""
55
66import json
7+ import subprocess
78import time
89
910import gradio as gr
1011import numpy as np
1112from typing import Union
1213
14+ import os
15+ import PyPDF2
16+
1317from fastchat .constants import (
1418 TEXT_MODERATION_MSG ,
1519 IMAGE_MODERATION_MSG ,
@@ -242,6 +246,56 @@ def clear_history(request: gr.Request):
242246 + ["" ]
243247 )
244248
def extract_text_from_pdf(pdf_file_path):
    """Extract the concatenated text of every page in a PDF file.

    Args:
        pdf_file_path: Filesystem path to the PDF to read.

    Returns:
        The extracted text of all pages joined together, or ``None`` if the
        file cannot be opened or parsed.
    """
    try:
        with open(pdf_file_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            # PyPDF2's extract_text() may return None for pages with no
            # extractable text; coalesce to "" so the join never fails.
            # Joining once also avoids quadratic string concatenation.
            return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        # Best-effort helper: log and signal failure with None rather than
        # crashing the caller on a malformed/unreadable PDF.
        logger.error(f"Failed to extract text from PDF: {e}")
        return None
261+
def llama_parse(pdf_path):
    """Convert a PDF to Markdown via the `llama-parse` CLI.

    Writes the converted document to ``outputs/<pdf_name>.md`` and returns
    its contents.

    Args:
        pdf_path: Path to the input PDF file.

    Returns:
        The Markdown text produced by llama-parse.

    Raises:
        subprocess.CalledProcessError: If the llama-parse CLI exits non-zero.
        OSError: If the output file cannot be read.
    """
    # Only provide the placeholder if no real key is configured: the original
    # unconditional assignment clobbered a valid LLAMA_CLOUD_API_KEY set in
    # the environment, breaking working deployments.
    os.environ.setdefault('LLAMA_CLOUD_API_KEY', 'LLAMA KEY')

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)

    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    markdown_file_path = os.path.join(output_dir, f'{pdf_name}.md')

    command = [
        'llama-parse',
        pdf_path,
        '--result-type', 'markdown',
        '--output-file', markdown_file_path,
    ]

    # check=True surfaces CLI failures as CalledProcessError instead of
    # silently reading a stale/missing output file.
    subprocess.run(command, check=True)

    with open(markdown_file_path, 'r', encoding='utf-8') as file:
        markdown_content = file.read()

    return markdown_content
284+
def wrap_query_context(user_query, query_context):
    """Render a user query plus its retrieved context as one Markdown string.

    The context is wrapped in a collapsible ``<details>`` element under a
    ``[QUERY CONTEXT]`` heading, followed by the query under ``[USER QUERY]``.
    """
    # TODO: refactor to split up user query and query context.
    # lines = input.split("\n\n[USER QUERY]", 1)
    # user_query = lines[1].strip()
    # query_context = lines[0][len('[QUERY CONTEXT]\n\n'): ]
    context_lines = [
        "[QUERY CONTEXT]",
        "<details>",
        "<summary>Expand context details</summary>",
        "",
        query_context,
        "",
        "</details>",
    ]
    query_section = f"\n\n[USER QUERY]\n\n{user_query}"
    return "\n".join(context_lines) + query_section
245299
246300def add_text (
247301 state0 ,
@@ -253,10 +307,14 @@ def add_text(
253307 request : gr .Request ,
254308):
255309 if isinstance (chat_input , dict ):
256- text , images = chat_input ["text" ], chat_input ["files" ]
310+ text , files = chat_input ["text" ], chat_input ["files" ]
257311 else :
258312 text = chat_input
259- images = []
313+ files = []
314+
315+ images = []
316+
317+ file_extension = os .path .splitext (files [0 ])[1 ].lower ()
260318
261319 ip = get_ip (request )
262320 logger .info (f"add_text (anony). ip: { ip } . len: { len (text )} " )
@@ -267,7 +325,7 @@ def add_text(
267325 if states [0 ] is None :
268326 assert states [1 ] is None
269327
270- if len (images ) > 0 :
328+ if len (files ) > 0 and file_extension != ".pdf" :
271329 model_left , model_right = get_battle_pair (
272330 context .all_vision_models ,
273331 VISION_BATTLE_TARGETS ,
@@ -363,6 +421,27 @@ def add_text(
363421 for i in range (num_sides ):
364422 if "deluxe" in states [i ].model_name :
365423 hint_msg = SLOW_MODEL_MSG
424+
425+ if file_extension == ".pdf" :
426+ document_text = llama_parse (files [0 ])
427+ post_processed_text = f"""
428+ The following is the content of a document:
429+
430+ { document_text }
431+
432+ Based on this document, answer the following question:
433+
434+ { text }
435+ """
436+
437+ post_processed_text = wrap_query_context (text , post_processed_text )
438+
439+ text = text [:BLIND_MODE_INPUT_CHAR_LEN_LIMIT ] # Hard cut-off
440+ for i in range (num_sides ):
441+ states [i ].conv .append_message (states [i ].conv .roles [0 ], post_processed_text )
442+ states [i ].conv .append_message (states [i ].conv .roles [1 ], None )
443+ states [i ].skip_next = False
444+
366445 return (
367446 states
368447 + [x .to_gradio_chatbot () for x in states ]
@@ -471,10 +550,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
471550 )
472551
473552 multimodal_textbox = gr .MultimodalTextbox (
474- file_types = ["image " ],
553+ file_types = ["file " ],
475554 show_label = False ,
476555 container = True ,
477- placeholder = "Enter your prompt or add image here" ,
556+ placeholder = "Enter your prompt or add a PDF file here" ,
478557 elem_id = "input_box" ,
479558 scale = 3 ,
480559 )
@@ -483,6 +562,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
483562 )
484563
485564 with gr .Row () as button_row :
565+ random_btn = gr .Button (value = "🔮 Random Image" , interactive = True )
486566 if random_questions :
487567 global vqa_samples
488568 with open (random_questions , "r" ) as f :
0 commit comments