Skip to content

Commit e2d30b9

Browse files
authored
Merge pull request #3668 from flairNLP/GH-3655-tokenization-on-predict
GH-3655: Tokenization on predict
2 parents ee8596c + d35c319 commit e2d30b9

File tree

10 files changed

+1061
-183
lines changed

10 files changed

+1061
-183
lines changed

flair/data.py

Lines changed: 281 additions & 164 deletions
Large diffs are not rendered by default.

flair/embeddings/transformer.py

Lines changed: 7 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -8,7 +8,7 @@
88
from abc import abstractmethod
99
from io import BytesIO
1010
from pathlib import Path
11-
from typing import Any, Literal, Optional, Union, cast
11+
from typing import Any, Callable, Dict, Literal, Optional, Union, cast
1212

1313
import torch
1414
import transformers
@@ -26,6 +26,7 @@
2626
LayoutLMv2FeatureExtractor,
2727
PretrainedConfig,
2828
PreTrainedTokenizer,
29+
T5Config,
2930
T5TokenizerFast,
3031
)
3132
from transformers.tokenization_utils_base import LARGE_INTEGER
@@ -674,7 +675,9 @@ def __build_transformer_model_inputs(
674675

675676
if self.feature_extractor is not None:
676677
images = [sent.get_metadata("image") for sent in sentences]
677-
image_encodings = self.feature_extractor(images, return_tensors="pt")["pixel_values"]
678+
# Cast self.feature_extractor to a callable type
679+
feature_extractor_callable = cast(Callable[..., Dict[str, Any]], self.feature_extractor)
680+
image_encodings = feature_extractor_callable(images, return_tensors="pt")["pixel_values"]
678681
if cpu_overflow_to_sample_mapping is not None:
679682
batched_image_encodings = [image_encodings[i] for i in cpu_overflow_to_sample_mapping]
680683
image_encodings = torch.stack(batched_image_encodings)
@@ -1138,7 +1141,8 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
11381141
if is_supported_t5_model(saved_config):
11391142
from transformers import T5EncoderModel
11401143

1141-
transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs, **kwargs)
1144+
# Cast saved_config to T5Config
1145+
transformer_model = T5EncoderModel(cast(T5Config, saved_config), **transformers_model_kwargs, **kwargs)
11421146
else:
11431147
transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs, **kwargs)
11441148
try:

flair/models/sequence_tagger_model.py

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -489,6 +489,12 @@ def predict(
489489
# filter empty sentences
490490
sentences = [sentence for sentence in sentences if len(sentence) > 0]
491491

492+
# Use the tokenizer property getter
493+
model_tokenizer = self.tokenizer
494+
if model_tokenizer is not None:
495+
for sentence in sentences:
496+
sentence.tokenizer = model_tokenizer
497+
492498
# reverse sort all sequences by their length
493499
reordered_sentences = sorted(sentences, key=len, reverse=True)
494500

flair/models/tars_model.py

Lines changed: 9 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -12,7 +12,7 @@
1212
from tqdm import tqdm
1313

1414
import flair
15-
from flair.data import Corpus, Dictionary, Sentence, Span
15+
from flair.data import Corpus, Dictionary, Sentence, Span, Token
1616
from flair.datasets import DataLoader, FlairDatapointDataset
1717
from flair.embeddings import (
1818
TokenEmbeddings,
@@ -594,11 +594,15 @@ def predict(
594594
continue
595595

596596
# only add if all tokens have no label
597-
if tag_this:
597+
if tag_this and isinstance(label.data_point, Span):
598+
# get tokens and filter None (they don't exist, but we have to make this explicit for mypy)
599+
token_list = [
600+
sentence.get_token(token.idx - label_length) for token in label.data_point
601+
]
602+
token_list_no_none = [token for token in token_list if isinstance(token, Token)]
603+
598604
# make and add a corresponding predicted span
599-
predicted_span = Span(
600-
[sentence.get_token(token.idx - label_length) for token in label.data_point]
601-
)
605+
predicted_span = Span(token_list_no_none)
602606
predicted_span.add_label(label_name, value=label.value, score=label.score)
603607

604608
# set indices so that no token can be tagged twice

flair/nn/model.py

Lines changed: 129 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -14,6 +14,7 @@
1414
from tqdm import tqdm
1515

1616
import flair
17+
import flair.tokenization
1718
from flair.class_utils import get_non_abstract_subclasses
1819
from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset
1920
from flair.datasets import DataLoader, FlairDatapointDataset
@@ -22,6 +23,7 @@
2223
from flair.embeddings.base import load_embeddings
2324
from flair.file_utils import Tqdm, load_torch_state
2425
from flair.training_utils import EmbeddingStorageMode, Result, store_embeddings
26+
import importlib
2527

2628
log = logging.getLogger("flair")
2729

@@ -44,6 +46,32 @@ def __init__(self) -> None:
4446
self.optimizer_state_dict: Optional[dict[str, Any]] = None
4547
self.scheduler_state_dict: Optional[dict[str, Any]] = None
4648

49+
# Internal storage for the tokenizer
50+
self._tokenizer: Optional[flair.tokenization.Tokenizer] = None
51+
52+
@property
53+
def tokenizer(self) -> Optional[flair.tokenization.Tokenizer]:
54+
"""
55+
Gets the tokenizer associated with this model.
56+
Returns:
57+
Optional[flair.tokenization.Tokenizer]: The tokenizer instance, or None if not set.
58+
"""
59+
return self._tokenizer
60+
61+
@tokenizer.setter
62+
def tokenizer(self, value: Optional[flair.tokenization.Tokenizer]) -> None:
63+
"""
64+
Sets the tokenizer for this model.
65+
Args:
66+
value (Optional[flair.tokenization.Tokenizer]): The tokenizer instance to set.
67+
"""
68+
if self._tokenizer is not value: # Basic check to avoid unnecessary logging if same instance
69+
log.debug(
70+
f"Model tokenizer changed from {self._tokenizer.__class__.__name__ if self._tokenizer else 'None'} "
71+
f"to {value.__class__.__name__ if value else 'None'}"
72+
)
73+
self._tokenizer = value
74+
4775
@property
4876
@abstractmethod
4977
def label_type(self) -> str:
@@ -104,19 +132,63 @@ def _get_state_dict(self) -> dict:
104132
- "optimizer_state_dict": The optimizer's state dictionary (if it exists)
105133
- "scheduler_state_dict": The scheduler's state dictionary (if it exists)
106134
- "model_card": Training parameters and metadata (if set)
135+
- "tokenizer_info": Information to reconstruct the tokenizer used during training (if any and serializable)
107136
"""
108137
# Always include the name of the Model class for which the state dict holds
109-
state_dict = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__}
138+
state = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__}
110139

111140
# Add optimizer state dict if it exists
112141
if hasattr(self, "optimizer_state_dict") and self.optimizer_state_dict is not None:
113-
state_dict["optimizer_state_dict"] = self.optimizer_state_dict
142+
state["optimizer_state_dict"] = self.optimizer_state_dict
114143

115144
# Add scheduler state dict if it exists
116145
if hasattr(self, "scheduler_state_dict") and self.scheduler_state_dict is not None:
117-
state_dict["scheduler_state_dict"] = self.scheduler_state_dict
146+
state["scheduler_state_dict"] = self.scheduler_state_dict
147+
148+
# -- Start Tokenizer Serialization Logic --
149+
tokenizer_info = None # Default: no tokenizer info saved
150+
151+
# Get the tokenizer
152+
current_tokenizer = self.tokenizer
153+
154+
if current_tokenizer is not None:
155+
156+
if hasattr(current_tokenizer, "to_dict") and callable(getattr(current_tokenizer, "to_dict")):
157+
try:
158+
potential_tokenizer_info = current_tokenizer.to_dict()
159+
160+
if (
161+
isinstance(potential_tokenizer_info, dict)
162+
and "class_module" in potential_tokenizer_info
163+
and "class_name" in potential_tokenizer_info
164+
):
165+
tokenizer_info = potential_tokenizer_info # Store the valid dict
166+
else:
167+
log.warning(
168+
f"Tokenizer {current_tokenizer.__class__.__name__} has a 'to_dict' method, "
169+
f"but it did not return a valid dictionary with 'class_module' and "
170+
f"'class_name'. Tokenizer will not be saved automatically."
171+
)
172+
# tokenizer_info remains None
173+
except Exception as e:
174+
log.warning(
175+
f"Error calling 'to_dict' on tokenizer {current_tokenizer.__class__.__name__}: {e}. "
176+
f"Tokenizer will not be saved automatically."
177+
)
178+
# tokenizer_info remains None
179+
else:
180+
log.warning(
181+
f"Tokenizer {current_tokenizer.__class__.__name__} does not implement the 'to_dict' method "
182+
f"required for automatic saving. It will not be saved automatically. "
183+
f"You may need to manually attach it after loading the model."
184+
)
185+
# tokenizer_info remains None
118186

119-
return state_dict
187+
# Add the determined tokenizer_info (either dict or None) to the state
188+
state["tokenizer_info"] = tokenizer_info # type: ignore[assignment]
189+
# -- End Tokenizer Serialization Logic --
190+
191+
return state
120192

121193
@classmethod
122194
def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
@@ -128,7 +200,6 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
128200
kwargs["embeddings"] = embeddings
129201

130202
model = cls(**kwargs)
131-
132203
model.load_state_dict(state["state_dict"])
133204

134205
# load optimizer state if it exists in the state dict
@@ -141,7 +212,49 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
141212
log.debug(f"Found scheduler state in model file with keys: {state['scheduler_state_dict'].keys()}")
142213
model.scheduler_state_dict = state["scheduler_state_dict"]
143214

144-
return model
215+
# --- Part 3: Load Tokenizer ---
216+
tokenizer_instance = None # Default to None
217+
if "tokenizer_info" in state and state["tokenizer_info"] is not None:
218+
tokenizer_info = state["tokenizer_info"]
219+
if isinstance(tokenizer_info, dict) and "class_module" in tokenizer_info and "class_name" in tokenizer_info:
220+
module_name = tokenizer_info["class_module"]
221+
class_name = tokenizer_info["class_name"]
222+
try:
223+
# ... (importlib logic, call from_dict, error handling) ...
224+
module = importlib.import_module(module_name)
225+
TokenizerClass = getattr(module, class_name)
226+
if hasattr(TokenizerClass, "from_dict") and callable(getattr(TokenizerClass, "from_dict")):
227+
tokenizer_instance = TokenizerClass.from_dict(tokenizer_info)
228+
log.info(f"Successfully loaded tokenizer '{class_name}' from '{module_name}'.")
229+
else:
230+
log.warning(
231+
f"Tokenizer class '{class_name}' found in '{module_name}', but it is missing the required "
232+
f"'from_dict' class method. Tokenizer cannot be loaded automatically."
233+
)
234+
except ImportError:
235+
log.warning(
236+
f"Could not import tokenizer module '{module_name}'. "
237+
f"Make sure the module containing '{class_name}' is installed and accessible."
238+
)
239+
except AttributeError:
240+
log.warning(f"Could not find tokenizer class '{class_name}' in module '{module_name}'.")
241+
except Exception as e:
242+
log.warning(
243+
f"Error reconstructing tokenizer '{class_name}' from module '{module_name}' "
244+
f"using 'from_dict': {e}"
245+
)
246+
247+
else:
248+
log.warning(
249+
"Found 'tokenizer_info' in saved model state, but it is invalid (must be a dict with "
250+
"'class_module' and 'class_name'). Tokenizer cannot be loaded automatically."
251+
)
252+
253+
# Assign the result (instance or None) to the model
254+
model._tokenizer = tokenizer_instance
255+
# --- End Tokenizer Loading ---
256+
257+
return model # Return the initialized model object
145258

146259
@staticmethod
147260
def _fetch_model(model_identifier: str):
@@ -933,6 +1046,14 @@ def predict(
9331046
if isinstance(sentences[0], Sentence):
9341047
Sentence.set_context_for_sentences(typing.cast(list[Sentence], sentences))
9351048

1049+
# Use the tokenizer property getter
1050+
model_tokenizer = self.tokenizer
1051+
if model_tokenizer is not None:
1052+
for sentence in sentences:
1053+
# this affects only models that call predict over Sentence or EncodedSentence objects (not Spans, etc.)
1054+
if isinstance(sentence, Sentence):
1055+
sentence.tokenizer = model_tokenizer
1056+
9361057
reordered_sentences = self._sort_data(sentences)
9371058

9381059
if len(reordered_sentences) == 0:
@@ -975,8 +1096,8 @@ def predict(
9751096
# if anything could possibly be predicted
9761097
if data_points:
9771098
# remove previously predicted labels of this type
978-
for sentence in data_points:
979-
sentence.remove_labels(label_name)
1099+
for data_point in data_points:
1100+
data_point.remove_labels(label_name)
9801101

9811102
if return_loss:
9821103
# filter data points that have labels outside of dictionary

0 commit comments

Comments (0)