Commit 3857e22

added pdf context support
1 parent 1cd4b74 commit 3857e22

File tree

1 file changed (+85, -5 lines)

1 file changed

+85
-5
lines changed

fastchat/serve/gradio_block_arena_vision_anony.py

Lines changed: 85 additions & 5 deletions
@@ -4,12 +4,16 @@
 """
 
 import json
+import subprocess
 import time
 
 import gradio as gr
 import numpy as np
 from typing import Union
 
+import os
+import PyPDF2
+
 from fastchat.constants import (
     TEXT_MODERATION_MSG,
     IMAGE_MODERATION_MSG,
@@ -242,6 +246,56 @@ def clear_history(request: gr.Request):
         + [""]
     )
 
+def extract_text_from_pdf(pdf_file_path):
+    """Extract text from a PDF file."""
+    try:
+        with open(pdf_file_path, 'rb') as f:
+            reader = PyPDF2.PdfReader(f)
+            pdf_text = ""
+            for page in reader.pages:
+                pdf_text += page.extract_text()
+            return pdf_text
+    except Exception as e:
+        logger.error(f"Failed to extract text from PDF: {e}")
+        return None
+
+def llama_parse(pdf_path):
+    os.environ['LLAMA_CLOUD_API_KEY'] = 'LLAMA KEY'
+
+    output_dir = 'outputs'
+    os.makedirs(output_dir, exist_ok=True)
+
+    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
+    markdown_file_path = os.path.join(output_dir, f'{pdf_name}.md')
+
+    command = [
+        'llama-parse',
+        pdf_path,
+        '--result-type', 'markdown',
+        '--output-file', markdown_file_path
+    ]
+
+    subprocess.run(command, check=True)
+
+    with open(markdown_file_path, 'r', encoding='utf-8') as file:
+        markdown_content = file.read()
+
+    return markdown_content
+
+def wrap_query_context(user_query, query_context):
+    #TODO: refactor to split up user query and query context.
+    # lines = input.split("\n\n[USER QUERY]", 1)
+    # user_query = lines[1].strip()
+    # query_context = lines[0][len('[QUERY CONTEXT]\n\n'): ]
+    reformatted_query_context = (
+        f"[QUERY CONTEXT]\n"
+        f"<details>\n"
+        f"<summary>Expand context details</summary>\n\n"
+        f"{query_context}\n\n"
+        f"</details>"
+    )
+    markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
+    return markdown
 
 def add_text(
     state0,
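
The three helpers added in this hunk can be exercised on their own. A minimal usage sketch of the two extraction paths, assuming FastChat and PyPDF2 are installed and this module imports cleanly; "paper.pdf" is a placeholder, and llama_parse additionally requires the llama-parse CLI on PATH plus a real LLAMA_CLOUD_API_KEY in place of the hard-coded placeholder (only llama_parse is actually wired into add_text further down in this diff):

# Standalone usage sketch, not part of this commit; "paper.pdf" is a placeholder.
from fastchat.serve.gradio_block_arena_vision_anony import (
    extract_text_from_pdf,  # local PyPDF2 extraction; returns None on failure
    llama_parse,            # shells out to the llama-parse CLI; returns markdown
)

document_text = extract_text_from_pdf("paper.pdf")
if document_text is None:
    # Fall back to LlamaParse's markdown output if PyPDF2 could not read the file.
    document_text = llama_parse("paper.pdf")

print(document_text[:500])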
@@ -253,10 +307,14 @@ def add_text(
     request: gr.Request,
 ):
     if isinstance(chat_input, dict):
-        text, images = chat_input["text"], chat_input["files"]
+        text, files = chat_input["text"], chat_input["files"]
     else:
         text = chat_input
-        images = []
+        files = []
+
+    images = []
+
+    file_extension = os.path.splitext(files[0])[1].lower()
 
     ip = get_ip(request)
     logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -267,7 +325,7 @@ def add_text(
     if states[0] is None:
         assert states[1] is None
 
-        if len(images) > 0:
+        if len(files) > 0 and file_extension != ".pdf":
             model_left, model_right = get_battle_pair(
                 context.all_vision_models,
                 VISION_BATTLE_TARGETS,
@@ -363,6 +421,27 @@ def add_text(
     for i in range(num_sides):
         if "deluxe" in states[i].model_name:
             hint_msg = SLOW_MODEL_MSG
+
+    if file_extension == ".pdf":
+        document_text = llama_parse(files[0])
+        post_processed_text = f"""
+The following is the content of a document:
+
+{document_text}
+
+Based on this document, answer the following question:
+
+{text}
+"""
+
+        post_processed_text = wrap_query_context(text, post_processed_text)
+
+        text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
+        for i in range(num_sides):
+            states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
+            states[i].conv.append_message(states[i].conv.roles[1], None)
+            states[i].skip_next = False
+
     return (
         states
         + [x.to_gradio_chatbot() for x in states]
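
For reference, a rough sketch of the message this ".pdf" branch appends to both conversations, using toy values in place of the parsed document; the exact whitespace depends on the f-string above:

# Toy reconstruction of the per-model message built in the ".pdf" branch above.
# In the real flow, document_text comes from llama_parse(files[0]).
from fastchat.serve.gradio_block_arena_vision_anony import wrap_query_context

document_text = "LoRA fine-tunes large models by training low-rank adapter matrices."
text = "What problem does LoRA address?"

inner = (
    "The following is the content of a document:\n\n"
    f"{document_text}\n\n"
    "Based on this document, answer the following question:\n\n"
    f"{text}"
)
post_processed_text = wrap_query_context(text, inner)
# post_processed_text now reads roughly:
#   [QUERY CONTEXT]
#   <details>
#   <summary>Expand context details</summary>
#
#   The following is the content of a document: ...
#
#   </details>
#
#   [USER QUERY]
#
#   What problem does LoRA address?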
@@ -471,10 +550,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
     )
 
     multimodal_textbox = gr.MultimodalTextbox(
-        file_types=["image"],
+        file_types=["file"],
         show_label=False,
         container=True,
-        placeholder="Enter your prompt or add image here",
+        placeholder="Enter your prompt or add a PDF file here",
         elem_id="input_box",
         scale=3,
     )
@@ -483,6 +562,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
     )
 
     with gr.Row() as button_row:
+        random_btn = gr.Button(value="🔮 Random Image", interactive=True)
         if random_questions:
             global vqa_samples
             with open(random_questions, "r") as f:
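
The UI changes above switch the arena's textbox from image-only uploads to arbitrary files. A minimal sketch of the same pattern outside the arena, assuming Gradio 4.x with gr.MultimodalTextbox; the echo-style handler is hypothetical and only shows how the submitted value splits into text and files, mirroring add_text above:

# Minimal standalone sketch (assumes Gradio 4.x); the handler only echoes the input.
import os
import gradio as gr

def respond(chat_input, history):
    text, files = chat_input["text"], chat_input["files"]
    if files and os.path.splitext(files[0])[1].lower() == ".pdf":
        reply = f"Received PDF '{os.path.basename(files[0])}' with question: {text}"
    else:
        reply = f"Received: {text}"
    return history + [(text, reply)]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    box = gr.MultimodalTextbox(
        file_types=["file"],  # accept any file type, not only images
        placeholder="Enter your prompt or add a PDF file here",
    )
    box.submit(respond, [box, chatbot], [chatbot])

demo.launch()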
