Skip to content

Commit 85767e5

Browse files
committed
Changed file detection to magic numbers and removed unnecessary libraries and code
1 parent cc66890 commit 85767e5

File tree

1 file changed

+42
-30
lines changed

1 file changed

+42
-30
lines changed

fastchat/serve/gradio_block_arena_vision_anony.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"""
55

66
import json
7-
import subprocess
87
import time
98

109
import gradio as gr
@@ -13,6 +12,8 @@
1312

1413
import os
1514
import PyPDF2
15+
import nest_asyncio
16+
from llama_parse import LlamaParse
1617

1718
from fastchat.constants import (
1819
TEXT_MODERATION_MSG,
@@ -246,22 +247,37 @@ def clear_history(request: gr.Request):
246247
+ [""]
247248
)
248249

249-
def extract_text_from_pdf(pdf_file_path):
250-
"""Extract text from a PDF file."""
250+
def is_pdf(file_path):
251251
try:
252-
with open(pdf_file_path, 'rb') as f:
253-
reader = PyPDF2.PdfReader(f)
254-
pdf_text = ""
255-
for page in reader.pages:
256-
pdf_text += page.extract_text()
257-
return pdf_text
252+
with open(file_path, 'rb') as file:
253+
header = file.read(5) # Read the first 5 bytes
254+
return header == b'%PDF-'
258255
except Exception as e:
259-
logger.error(f"Failed to extract text from PDF: {e}")
260-
return None
261-
262-
import os
263-
import nest_asyncio
264-
from llama_parse import LlamaParse
256+
print(f"Error: {e}")
257+
return False
258+
259+
def is_image(file_path):
    """Return True if the file at *file_path* starts with a known image signature.

    The check looks only at the leading bytes of the file ("magic numbers"),
    not at its extension. Recognized formats: JPEG, PNG, GIF, BMP, ICO, TIFF
    and WebP.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when the header matches one of the recognized image formats,
        False otherwise -- including when the file cannot be opened or read.
    """
    # Prefixes that identify an image format on their own.
    image_signatures = (
        b"\xff\xd8\xff",  # JPEG
        b"\x89PNG\r\n\x1a\n",  # PNG
        b"GIF87a",  # GIF
        b"GIF89a",  # GIF
        b"BM",  # BMP
        b"\x00\x00\x01\x00",  # ICO
        b"\x49\x49\x2a\x00",  # TIFF, little-endian
        b"\x4d\x4d\x00\x2a",  # TIFF, big-endian
    )
    try:
        with open(file_path, "rb") as file:
            # 12 bytes cover the longest check below:
            # "RIFF" tag (4) + chunk size (4) + form type (4).
            header = file.read(12)
    except Exception as e:
        # Deliberately best-effort: any failure to read means "not an image".
        print(f"Error reading file: {e}")
        return False

    if header.startswith(image_signatures):
        return True

    # RIFF is a generic container shared with WAV, AVI, etc.; only a RIFF
    # file whose form type is "WEBP" is actually a WebP image.
    return header[:4] == b"RIFF" and header[8:12] == b"WEBP"
265281

266282
nest_asyncio.apply() # Ensure compatibility with async environments
267283

@@ -278,14 +294,7 @@ def pdf_parse(pdf_path):
278294
language="en" # Set language (default is English)
279295
)
280296

281-
# Prepare the output directory and file name
282-
output_dir = "outputs"
283-
os.makedirs(output_dir, exist_ok=True)
284-
285297
pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
286-
markdown_file_path = os.path.join(output_dir, f"{pdf_name}.md")
287-
288-
# Load and parse the PDF
289298
extra_info = {"file_name": pdf_name}
290299

291300
with open(pdf_path, "rb") as pdf_file:
@@ -326,10 +335,14 @@ def add_text(
326335
else:
327336
text = chat_input
328337
files = []
329-
338+
330339
images = []
331340

332-
file_extension = os.path.splitext(files[0])[1].lower()
341+
# currently support up to one pdf or one image
342+
# if is_pdf(files[0]):
343+
# pdfs = files
344+
if is_image(files[0]):
345+
images = files
333346

334347
ip = get_ip(request)
335348
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -340,7 +353,7 @@ def add_text(
340353
if states[0] is None:
341354
assert states[1] is None
342355

343-
if len(files) > 0 and file_extension != ".pdf":
356+
if len(files) > 0 and is_image(files[0]):
344357
model_left, model_right = get_battle_pair(
345358
context.all_vision_models,
346359
VISION_BATTLE_TARGETS,
@@ -423,7 +436,7 @@ def add_text(
423436
+ [""]
424437
)
425438

426-
if file_extension != ".pdf":
439+
if is_image(files[0]):
427440
text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
428441
for i in range(num_sides):
429442
post_processed_text = _prepare_text_with_image(
@@ -438,9 +451,9 @@ def add_text(
438451
if "deluxe" in states[i].model_name:
439452
hint_msg = SLOW_MODEL_MSG
440453

441-
if file_extension == ".pdf":
454+
if is_pdf(files[0]):
442455
document_text = pdf_parse(files[0])
443-
post_processed_text = f"""
456+
prompt_text = f"""
444457
The following is the content of a document:
445458
446459
{document_text}
@@ -449,8 +462,7 @@ def add_text(
449462
450463
{text}
451464
"""
452-
453-
post_processed_text = wrap_query_context(text, post_processed_text)
465+
post_processed_text = wrap_query_context(text, prompt_text)
454466

455467
# text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
456468
for i in range(num_sides):

0 commit comments

Comments
 (0)