Skip to content

Commit 5c52665

Browse files
committed
support multimodal pdfchat and switch to marker pdf
1 parent e4c0f3b commit 5c52665

File tree

3 files changed

+139
-33
lines changed

3 files changed

+139
-33
lines changed

fastchat/conversation.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,18 +362,39 @@ def update_last_message(self, message: str):
362362
def to_gradio_chatbot(self):
363363
"""Convert the conversation to gradio chatbot format."""
364364
from fastchat.serve.vision.image import ImageFormat
365+
import re
365366

366367
ret = []
367368
for i, (role, msg) in enumerate(self.messages[self.offset :]):
368369
if i % 2 == 0:
369370
if type(msg) is tuple:
370371
msg, images = msg
371-
image = images[0] # Only one image on gradio at one time
372-
if image.image_format == ImageFormat.URL:
373-
img_str = f'<img src="{image.url}" alt="user upload image" />'
374-
elif image.image_format == ImageFormat.BYTES:
375-
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
376-
msg = img_str + msg.replace("<image>\n", "").strip()
372+
373+
pattern = re.compile("!\[\]\(_page_\d_Figure_\d\.jpeg\)")
374+
embed_locations = pattern.findall(msg)
375+
376+
pdfchat = False
377+
for i, embed_str in enumerate(embed_locations):
378+
if i >= len(images):
379+
break
380+
381+
image = images[i]
382+
msg = msg.replace(
383+
embed_str,
384+
f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="document image" />',
385+
)
386+
pdfchat = True
387+
388+
if not pdfchat:
389+
# vision arena only supports one image on gradio at one time
390+
image = images[0]
391+
if image.image_format == ImageFormat.URL:
392+
img_str = (
393+
f'<img src="{image.url}" alt="user upload image" />'
394+
)
395+
elif image.image_format == ImageFormat.BYTES:
396+
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
397+
msg = img_str + msg.replace("<image>\n", "").strip()
377398

378399
ret.append([msg, None])
379400
else:

fastchat/serve/gradio_block_arena_vision.py

Lines changed: 95 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,51 @@ def wrap_pdfchat_query(query, document):
232232
return reformatted_query_context
233233

234234

235-
LLAMA_PARSE_MAX_RETRY = 2
236-
LLAMAPARSE_SUPPORTED_LANGS = {
235+
# LLAMA_PARSE_MAX_RETRY = 2
236+
# LLAMAPARSE_SUPPORTED_LANGS = {
237+
# "English": "en",
238+
# "Chinese": "ch_sim",
239+
# "Russian": "ru",
240+
# "Spanish": "es",
241+
# "Japanese": "ja",
242+
# "Korean": "ko",
243+
# "French": "fr",
244+
# "German": "de",
245+
# "Vietnamese": "vi",
246+
# }
247+
248+
249+
# def parse_pdf(file_path):
250+
# from llama_parse import LlamaParse
251+
252+
# assert (
253+
# "LLAMA_CLOUD_API_KEY" in os.environ
254+
# ), "Make sure to specify LlamaParse API key."
255+
256+
# for _ in range(LLAMA_PARSE_MAX_RETRY):
257+
# try:
258+
# documents = LlamaParse(
259+
# result_type="markdown",
260+
# verbose=True,
261+
# languages=list(LLAMAPARSE_SUPPORTED_LANGS.values()),
262+
# accurate_mode=True,
263+
# ).load_data(file_path)
264+
# assert len(documents) > 0
265+
# break
266+
# except AssertionError as e:
267+
# continue
268+
269+
# output = "\n".join(
270+
# [f"Page {i+1}:\n{doc.text}\n" for i, doc in enumerate(documents)]
271+
# )
272+
273+
# return output
274+
275+
276+
PDFPARSE_MAX_RETRY = 2
277+
PDFPARSE_SUPPORTED_LANGS = {
237278
"English": "en",
238-
"Chinese": "ch_sim",
279+
"Chinese": "zh",
239280
"Russian": "ru",
240281
"Spanish": "es",
241282
"Japanese": "ja",
@@ -244,33 +285,36 @@ def wrap_pdfchat_query(query, document):
244285
"German": "de",
245286
"Vietnamese": "vi",
246287
}
288+
MARKER_PDFPARSE_CONFIG = {
289+
"output_format": "markdown",
290+
"languages": ",".join(PDFPARSE_SUPPORTED_LANGS.values()),
291+
}
247292

248293

249294
def parse_pdf(file_path):
250-
from llama_parse import LlamaParse
251-
252-
assert (
253-
"LLAMA_CLOUD_API_KEY" in os.environ
254-
), "Make sure to specify LlamaParse API key."
295+
from marker.config.parser import ConfigParser
296+
from marker.models import create_model_dict
297+
from marker.converters.pdf import PdfConverter
255298

256-
for _ in range(LLAMA_PARSE_MAX_RETRY):
299+
output_md, output_images = None, None
300+
for _ in range(PDFPARSE_MAX_RETRY):
257301
try:
258-
documents = LlamaParse(
259-
result_type="markdown",
260-
verbose=True,
261-
languages=list(LLAMAPARSE_SUPPORTED_LANGS.values()),
262-
accurate_mode=True,
263-
).load_data(file_path)
264-
assert len(documents) > 0
302+
config_parser = ConfigParser(MARKER_PDFPARSE_CONFIG)
303+
304+
converter = PdfConverter(
305+
config=config_parser.generate_config_dict(),
306+
artifact_dict=create_model_dict(),
307+
processor_list=config_parser.get_processors(),
308+
renderer=config_parser.get_renderer(),
309+
)
310+
rendered = converter(file_path)
311+
output_md = rendered.markdown
312+
output_images = list(rendered.images.values())
265313
break
266314
except AssertionError as e:
267315
continue
268316

269-
output = "\n".join(
270-
[f"Page {i+1}:\n{doc.text}\n" for i, doc in enumerate(documents)]
271-
)
272-
273-
return output
317+
return output_md, output_images
274318

275319

276320
def _prepare_text_with_image(state, text, images, csam_flag):
@@ -284,12 +328,26 @@ def _prepare_text_with_image(state, text, images, csam_flag):
284328
return text
285329

286330

331+
# def _prepare_text_with_pdf(text, pdfs):
332+
# if len(pdfs) > 0:
333+
# document_content = parse_pdf(pdfs[0])
334+
# print("Document processed")
335+
# text = wrap_pdfchat_query(text, document_content)
336+
337+
# return text
338+
339+
287340
def _prepare_text_with_pdf(text, pdfs):
288341
if len(pdfs) > 0:
289-
document_content = parse_pdf(pdfs[0])
342+
parsed_text, imgs = parse_pdf(pdfs[0])
290343
print("Document processed")
291-
text = wrap_pdfchat_query(text, document_content)
344+
wrapped_text = wrap_pdfchat_query(text, parsed_text)
292345

346+
imgs = convert_pdf_images_to_conversation_format(imgs)
347+
348+
if len(imgs) > 0:
349+
return wrapped_text, imgs
350+
return wrapped_text
293351
return text
294352

295353

@@ -307,6 +365,20 @@ def convert_images_to_conversation_format(images):
307365
return conv_images
308366

309367

368+
def convert_pdf_images_to_conversation_format(images):
369+
MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
370+
conv_images = []
371+
if len(images) > 0:
372+
for img in images:
373+
# pdf parser returns a PIL image object instead of path
374+
conv_images.append(
375+
Image(url="").to_conversation_format(
376+
MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB, pil_img=img
377+
)
378+
)
379+
return conv_images
380+
381+
310382
def moderate_input(state, text, all_conv_text, model_list, images, ip):
311383
text_flagged = moderation_filter(all_conv_text, model_list)
312384
# flagged = moderation_filter(text, [state.model_name])

fastchat/serve/vision/image.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import base64
22
from enum import auto, IntEnum
33
from io import BytesIO
4-
54
from pydantic import BaseModel
65

76

@@ -115,11 +114,25 @@ def convert_url_to_image_bytes(self, max_image_size_mb):
115114

116115
return image_format, img_base64_str
117116

118-
def to_conversation_format(self, max_image_size_mb):
119-
image_format, image_bytes = self.convert_url_to_image_bytes(
120-
max_image_size_mb=max_image_size_mb
117+
def convert_pil_image_to_image_bytes(self, pil_img, max_image_size_mb):
118+
image_format, image_bytes = self.resize_image_and_return_image_in_bytes(
119+
pil_img, max_image_size_mb
121120
)
122121

122+
img_base64_str = base64.b64encode(image_bytes.getvalue()).decode()
123+
124+
return image_format, img_base64_str
125+
126+
def to_conversation_format(self, max_image_size_mb, pil_img=None):
127+
if pil_img:
128+
image_format, image_bytes = self.convert_pil_image_to_image_bytes(
129+
pil_img=pil_img, max_image_size_mb=max_image_size_mb
130+
)
131+
else:
132+
image_format, image_bytes = self.convert_url_to_image_bytes(
133+
max_image_size_mb=max_image_size_mb
134+
)
135+
123136
self.filetype = image_format
124137
self.image_format = ImageFormat.BYTES
125138
self.base64_str = image_bytes

0 commit comments

Comments
 (0)