Commit b78d81f

fix: performance issue code formula model (#78)
Signed-off-by: Matteo Omenetti <[email protected]>
1 parent: d6a3549

File tree: 1 file changed (+59, −5)

docling_ibm_models/code_formula_model/code_formula_predictor.py (59 additions, 5 deletions)
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
 from PIL import Image
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
 from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM
 from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import (
@@ -18,6 +18,22 @@
 _log = logging.getLogger(__name__)
 
 
+class StopOnString(StoppingCriteria):
+    def __init__(self, tokenizer, stop_string):
+        self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False)
+
+    def __call__(self, input_ids, scores, **kwargs):
+        for sequence in input_ids:
+            sequence_list = sequence.tolist()
+            for i in range(len(sequence_list) - len(self.stop_token_ids) + 1):
+                if (
+                    sequence_list[i : i + len(self.stop_token_ids)]
+                    == self.stop_token_ids
+                ):
+                    return True
+        return False
+
+
 class CodeFormulaPredictor:
     """
     Code and Formula Predictor using a multi-modal vision-language model.
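The new StopOnString criterion scans every sequence in the batch for the token ids of a given stop string and halts generation once any sequence contains them. A minimal sketch of how it behaves in isolation, assuming the class is importable from the module; the facebook/opt-125m tokenizer is illustrative, not the checkpoint the predictor actually loads:

```python
# Minimal sketch; the tokenizer checkpoint is illustrative only.
import torch
from transformers import AutoTokenizer

from docling_ibm_models.code_formula_model.code_formula_predictor import StopOnString

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
criterion = StopOnString(tokenizer, r" \quad \quad \quad \quad")

# A batch with one sequence whose tail degenerated into repeated \quad.
text = r"x + y \quad \quad \quad \quad"
input_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=False)])

# __call__ receives the ids generated so far; scores are unused here.
print(criterion(input_ids, None))  # True: the stop string's token ids were found
```

Note that the match happens at the token-id level, so it only fires when the generated tokens tokenize the stop string the same way tokenizer.encode does; a BPE merge across the boundary could in principle prevent a match.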
@@ -127,12 +143,37 @@ def _get_prompt(self, label: str) -> str:
 
         return prompt
 
+    def _strip(self, text: str):
+        """
+        Removes any occurrences of the substrings in remove_list from the end of text.
+
+        Parameters
+        ----------
+        text : str
+            The original string.
+
+        Returns
+        -------
+        str
+            The trimmed string.
+        """
+        remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"]
+        changed = True
+        while changed:
+            changed = False
+            for substr in remove_list:
+                if text.endswith(substr):
+                    text = text[: -len(substr)]
+                    changed = True
+
+        return text.strip()
+
     @torch.inference_mode()
     def predict(
         self,
         images: List[Union[Image.Image, np.ndarray]],
         labels: List[str],
-        temperature: Optional[float] = 0.1,
+        temperature: Optional[float] = 0.0,
     ) -> List[str]:
         """
         Predicts the textual representation of input images (code or LaTeX).
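_strip trims whatever degenerate tail is still attached to the decoded text once generation stops. A standalone copy of its logic for illustration; strip_degenerate_tail is our name, not part of the module:

```python
# Standalone copy of the _strip logic above, for illustration only.
def strip_degenerate_tail(text: str) -> str:
    remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"]
    changed = True
    while changed:
        changed = False
        for substr in remove_list:
            # Peel matching substrings off the end until none remain.
            if text.endswith(substr):
                text = text[: -len(substr)]
                changed = True
    return text.strip()

print(strip_degenerate_tail(r"E = m c^{2} \quad"))     # 'E = m c^{2}'
print(strip_degenerate_tail("x = 0 c c c c c c c c"))  # 'x = 0'
```

The second call shows the loop removing " c c c c" twice before the final strip(). Because the LaTeX entries in remove_list carry no leading space, a space-separated run such as "\quad \quad" loses only its last element here, with the leftover whitespace handled by strip().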
@@ -144,7 +185,7 @@ def predict(
         labels : List[str]
             List of labels indicating the type of each image ('code' or 'formula').
         temperature : Optional[float]
-            Sampling temperature for generation, by default set to 0.1.
+            Sampling temperature for generation, by default set to 0.0.
 
         Returns
         -------
@@ -198,6 +239,16 @@ def predict(
         prompt_ids = tokenized["input_ids"]
         attention_mask = tokenized["attention_mask"]
 
+        stopping_criteria = StoppingCriteriaList(
+            [
+                StopOnString(self._tokenizer, r" \quad \quad \quad \quad"),
+                StopOnString(self._tokenizer, r" \\ \\ \\ \\"),
+                StopOnString(self._tokenizer, r" \, \, \, \,"),
+                StopOnString(self._tokenizer, r" c c c c c c c c c c c c c c c c"),
+                StopOnString(self._tokenizer, r" l l l l l l l l l l l l l l l l l"),
+            ]
+        )
+
         if self._device == "cpu":
             output_ids_list = self._model.generate(
                 input_ids=prompt_ids,
@@ -207,7 +258,8 @@
                 temperature=temperature,
                 max_new_tokens=4096 - prompt_ids.shape[1],
                 use_cache=True,
-                no_repeat_ngram_size=300,
+                no_repeat_ngram_size=200,
+                stopping_criteria=stopping_criteria,
             )
         else:
             with torch.autocast(device_type=self._device, dtype=torch.bfloat16):
@@ -218,11 +270,13 @@
                 temperature=temperature,
                 max_new_tokens=4096 - prompt_ids.shape[1],
                 use_cache=True,
-                no_repeat_ngram_size=300,
+                no_repeat_ngram_size=200,
+                stopping_criteria=stopping_criteria,
             )
 
         outputs = self._tokenizer.batch_decode(
             output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True
         )
+        outputs = [self._strip(output) for output in outputs]
 
         return outputs
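Taken together, the diff addresses the performance issue on three fronts: generation halts as soon as the output degenerates into a known repetition pattern, the no_repeat_ngram_size window shrinks from 300 to 200, and any residual repetition tail is trimmed after decoding. A generic sketch of the same generation-time wiring, with a small public OPT checkpoint and a made-up prompt standing in for the predictor's SamOPT model and its inputs:

```python
# Generic sketch; model, tokenizer, and prompt are illustrative stand-ins.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from docling_ibm_models.code_formula_model.code_formula_predictor import StopOnString

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

prompt_ids = tokenizer("\\begin{array}{", return_tensors="pt").input_ids

stopping_criteria = StoppingCriteriaList(
    [
        StopOnString(tokenizer, r" \quad \quad \quad \quad"),
        StopOnString(tokenizer, r" \\ \\ \\ \\"),
    ]
)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=64,
        use_cache=True,
        no_repeat_ngram_size=200,             # lowered from 300 in this commit
        stopping_criteria=stopping_criteria,  # stop early on degenerate tails
    )

# Decode only the newly generated tokens, as the predictor does.
print(tokenizer.batch_decode(output_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True))
```

Stopping early is what saves the time: without it, a degenerated sequence keeps generating until the max_new_tokens budget (up to 4096 tokens in the predictor) is exhausted.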
