Skip to content

Commit 1f60d91

Browse files
committed
add demo_zh.png, html output and latex code postprocess
1 parent ec2881d commit 1f60d91

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

demo/demo.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
import torch
23
import argparse
34

@@ -9,23 +10,38 @@ def parse_config():
910
parser = argparse.ArgumentParser(description='arg parser')
1011
parser.add_argument('--image_path', type=str, default='demo.png', help='data path for table image')
1112
parser.add_argument('--ckpt_path', type=str, default='', help='ckpt path for table model')
13+
parser.add_argument('--cpu', action='store_true', default=False, help='using cpu for inference')
14+
parser.add_argument('--html', action='store_true', default=False, help='output html format table code')
1215
args = parser.parse_args()
1316
return args
1417

1518
def main():
1619
args = parse_config()
20+
if args.html:
21+
from pypandoc import convert_text
1722

1823
# build model
19-
model = build_model(args.ckpt_path, max_new_tokens=4096, max_time=120)
24+
model = build_model(args.ckpt_path, max_new_tokens=4096, max_time=120, use_gpu=(not args.cpu))
25+
if not args.cpu:
26+
model = model.cuda()
2027

2128
# model inference
2229
raw_image = Image.open(args.image_path)
30+
31+
start_time = time.time()
2332
with torch.no_grad():
2433
output = model(raw_image)
2534

2635
# show output latex code of table
36+
cost_time = time.time() - start_time
37+
print(f"total cost time: {cost_time:.2f}s")
2738
for i, latex_code in enumerate(output):
28-
print(f"Table {i}:\n{latex_code}")
39+
if args.html:
40+
html_code = convert_text(latex_code, 'html', format='latex')
41+
print(f"Table {i} HTML code:\n{html_code}")
42+
else:
43+
print(f"Table {i} LaTex code:\n{latex_code}")
44+
2945

3046
if __name__ == '__main__':
3147
main()

demo/demo_zh.png

35 KB
Loading

struct_eqtable/model.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
1+
import re
12
import torch
23

34
from torch import nn
45
from transformers import AutoModelForVision2Seq, AutoProcessor
56

67

78
class StructTable(nn.Module):
8-
def __init__(self, model_path, max_new_tokens=2048, max_time=60):
9+
def __init__(self, model_path, max_new_tokens=2048, max_time=60, use_gpu=True):
910
super().__init__()
1011
self.model_path = model_path
1112
self.max_new_tokens = max_new_tokens
1213
self.max_generate_time = max_time
14+
self.use_gpu = use_gpu
1315

1416
# init model and image processor from ckpt path
1517
self.init_image_processor(model_path)
1618
self.init_model(model_path)
17-
19+
20+
self.special_str_list = ['\\midrule', '\\hline']
21+
22+
def postprocess_latex_code(self, code):
23+
for special_str in self.special_str_list:
24+
code = code.replace(special_str, special_str + ' ')
25+
return code
26+
1827
def init_model(self, model_path):
1928
self.model = AutoModelForVision2Seq.from_pretrained(model_path)
2029
self.model.eval()
21-
30+
2231
def init_image_processor(self, image_processor_path):
2332
self.data_processor = AutoProcessor.from_pretrained(image_processor_path)
2433

@@ -28,6 +37,9 @@ def forward(self, image):
2837
images=image,
2938
return_tensors='pt',
3039
)
40+
if self.use_gpu:
41+
for k, v in image_tokens.items():
42+
image_tokens[k] = v.cuda()
3143

3244
# generate text from image tokens
3345
model_output = self.model.generate(
@@ -37,5 +49,8 @@ def forward(self, image):
3749
max_time=self.max_generate_time
3850
)
3951
latex_codes = self.data_processor.batch_decode(model_output, skip_special_tokens=True)
52+
# postprocess
53+
for i, code in enumerate(latex_codes):
54+
latex_codes[i] = self.postprocess_latex_code(code)
4055

4156
return latex_codes

0 commit comments

Comments
 (0)