1+ import re
12import torch
23
34from torch import nn
45from transformers import AutoModelForVision2Seq , AutoProcessor
56
67
78class StructTable (nn .Module ):
8- def __init__ (self , model_path , max_new_tokens = 2048 , max_time = 60 ):
9+ def __init__ (self , model_path , max_new_tokens = 2048 , max_time = 60 , use_gpu = True ):
910 super ().__init__ ()
1011 self .model_path = model_path
1112 self .max_new_tokens = max_new_tokens
1213 self .max_generate_time = max_time
14+ self .use_gpu = use_gpu
1315
1416 # init model and image processor from ckpt path
1517 self .init_image_processor (model_path )
1618 self .init_model (model_path )
17-
19+
20+ self .special_str_list = ['\\ midrule' , '\\ hline' ]
21+
22+ def postprocess_latex_code (self , code ):
23+ for special_str in self .special_str_list :
24+ code = code .replace (special_str , special_str + ' ' )
25+ return code
26+
1827 def init_model (self , model_path ):
1928 self .model = AutoModelForVision2Seq .from_pretrained (model_path )
2029 self .model .eval ()
21-
30+
2231 def init_image_processor (self , image_processor_path ):
2332 self .data_processor = AutoProcessor .from_pretrained (image_processor_path )
2433
@@ -28,6 +37,9 @@ def forward(self, image):
2837 images = image ,
2938 return_tensors = 'pt' ,
3039 )
40+ if self .use_gpu :
41+ for k , v in image_tokens .items ():
42+ image_tokens [k ] = v .cuda ()
3143
3244 # generate text from image tokens
3345 model_output = self .model .generate (
@@ -37,5 +49,8 @@ def forward(self, image):
3749 max_time = self .max_generate_time
3850 )
3951 latex_codes = self .data_processor .batch_decode (model_output , skip_special_tokens = True )
52+ # postprocess
53+ for i , code in enumerate (latex_codes ):
54+ latex_codes [i ] = self .postprocess_latex_code (code )
4055
4156 return latex_codes
0 commit comments