88import numpy as np
99import torch
1010from PIL import Image
11- from transformers import AutoTokenizer
11+ from transformers import AutoTokenizer , StoppingCriteria , StoppingCriteriaList
1212
1313from docling_ibm_models .code_formula_model .models .sam_opt import SamOPTForCausalLM
1414from docling_ibm_models .code_formula_model .models .sam_opt_image_processor import (
1818_log = logging .getLogger (__name__ )
1919
2020
class StopOnString(StoppingCriteria):
    """
    Stopping criterion that ends generation once a stop string is produced.

    The stop string is tokenized once at construction time; each call then
    compares the most recent tokens of every sequence against those ids.

    Parameters
    ----------
    tokenizer
        Object exposing ``encode(text, add_special_tokens=False)`` that
        returns the token ids of ``text``.
    stop_string : str
        Text whose token ids, once emitted, should stop generation.
    """

    def __init__(self, tokenizer, stop_string):
        super().__init__()
        self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        """
        Report whether any sequence currently ends with the stop token ids.

        Parameters
        ----------
        input_ids
            Iterable of 1-D token-id sequences supporting slicing and
            ``tolist()`` (one entry per batch element).
        scores
            Unused; accepted for interface compatibility.

        Returns
        -------
        bool
            True if at least one sequence ends with the stop token ids.
        """
        stop_ids = self.stop_token_ids
        stop_len = len(stop_ids)

        # An empty stop sequence would trivially match everything and end
        # generation on the very first step; treat it as "never stop".
        if stop_len == 0:
            return False

        # Only the trailing window is inspected instead of rescanning the
        # whole sequence on every step. NOTE(review): this assumes the
        # criterion is evaluated after every generated token, so a newly
        # completed stop sequence always sits at the end. An occurrence that
        # lies entirely inside the prompt no longer triggers a stop.
        return any(
            sequence[-stop_len:].tolist() == stop_ids for sequence in input_ids
        )
35+
36+
2137class CodeFormulaPredictor :
2238 """
2339 Code and Formula Predictor using a multi-modal vision-language model.
@@ -127,12 +143,37 @@ def _get_prompt(self, label: str) -> str:
127143
128144 return prompt
129145
146+ def _strip (self , text : str ):
147+ """
148+ Removes any occurrences of the substrings in remove_list from the end of text.
149+
150+ Parameters
151+ ----------
152+ text : str
153+ The original string.
154+
155+ Returns
156+ -------
157+ str
158+ The trimmed string.
159+ """
160+ remove_list = [r"\quad" , r"\\" , r"\," , " c c c c" , " l l l l l" ]
161+ changed = True
162+ while changed :
163+ changed = False
164+ for substr in remove_list :
165+ if text .endswith (substr ):
166+ text = text [: - len (substr )]
167+ changed = True
168+
169+ return text .strip ()
170+
130171 @torch .inference_mode ()
131172 def predict (
132173 self ,
133174 images : List [Union [Image .Image , np .ndarray ]],
134175 labels : List [str ],
135- temperature : Optional [float ] = 0.1 ,
176+ temperature : Optional [float ] = 0.0 ,
136177 ) -> List [str ]:
137178 """
138179 Predicts the textual representation of input images (code or LaTeX).
@@ -144,7 +185,7 @@ def predict(
144185 labels : List[str]
145186 List of labels indicating the type of each image ('code' or 'formula').
146187 temperature : Optional[float]
147- Sampling temperature for generation, by default set to 0.1 .
188+ Sampling temperature for generation, by default set to 0.0 .
148189
149190 Returns
150191 -------
@@ -198,6 +239,16 @@ def predict(
198239 prompt_ids = tokenized ["input_ids" ]
199240 attention_mask = tokenized ["attention_mask" ]
200241
242+ stopping_criteria = StoppingCriteriaList (
243+ [
244+ StopOnString (self ._tokenizer , r" \quad \quad \quad \quad" ),
245+ StopOnString (self ._tokenizer , r" \\ \\ \\ \\" ),
246+ StopOnString (self ._tokenizer , r" \, \, \, \," ),
247+ StopOnString (self ._tokenizer , r" c c c c c c c c c c c c c c c c" ),
248+ StopOnString (self ._tokenizer , r" l l l l l l l l l l l l l l l l l" ),
249+ ]
250+ )
251+
201252 if self ._device == "cpu" :
202253 output_ids_list = self ._model .generate (
203254 input_ids = prompt_ids ,
@@ -207,7 +258,8 @@ def predict(
207258 temperature = temperature ,
208259 max_new_tokens = 4096 - prompt_ids .shape [1 ],
209260 use_cache = True ,
210- no_repeat_ngram_size = 300 ,
261+ no_repeat_ngram_size = 200 ,
262+ stopping_criteria = stopping_criteria ,
211263 )
212264 else :
213265 with torch .autocast (device_type = self ._device , dtype = torch .bfloat16 ):
@@ -218,11 +270,13 @@ def predict(
218270 temperature = temperature ,
219271 max_new_tokens = 4096 - prompt_ids .shape [1 ],
220272 use_cache = True ,
221- no_repeat_ngram_size = 300 ,
273+ no_repeat_ngram_size = 200 ,
274+ stopping_criteria = stopping_criteria ,
222275 )
223276
224277 outputs = self ._tokenizer .batch_decode (
225278 output_ids_list [:, prompt_ids .shape [1 ] :], skip_special_tokens = True
226279 )
280+ outputs = [self ._strip (output ) for output in outputs ]
227281
228282 return outputs
0 commit comments