1414from tqdm import tqdm
1515
1616import flair
17+ import flair .tokenization
1718from flair .class_utils import get_non_abstract_subclasses
1819from flair .data import DT , DT2 , Corpus , Dictionary , Sentence , _iter_dataset
1920from flair .datasets import DataLoader , FlairDatapointDataset
2223from flair .embeddings .base import load_embeddings
2324from flair .file_utils import Tqdm , load_torch_state
2425from flair .training_utils import EmbeddingStorageMode , Result , store_embeddings
26+ import importlib
2527
2628log = logging .getLogger ("flair" )
2729
@@ -44,6 +46,32 @@ def __init__(self) -> None:
4446 self .optimizer_state_dict : Optional [dict [str , Any ]] = None
4547 self .scheduler_state_dict : Optional [dict [str , Any ]] = None
4648
49+ # Internal storage for the tokenizer
50+ self ._tokenizer : Optional [flair .tokenization .Tokenizer ] = None
51+
@property
def tokenizer(self) -> Optional[flair.tokenization.Tokenizer]:
    """Return the tokenizer currently attached to this model.

    Returns:
        The object stored in ``self._tokenizer``, or ``None`` when no tokenizer has been set.
    """
    return self._tokenizer
60+
@tokenizer.setter
def tokenizer(self, value: Optional[flair.tokenization.Tokenizer]) -> None:
    """Attach a tokenizer to this model (or detach it by passing ``None``).

    A debug message naming the old and new tokenizer classes is logged whenever the
    new value is a different object than the one currently stored.

    Args:
        value: The tokenizer instance to store, or ``None`` to clear it.
    """
    previous = self._tokenizer

    # Identity comparison: re-assigning the very same instance is not worth logging.
    if previous is not value:
        # Compare against None explicitly instead of relying on truthiness, so that a
        # tokenizer whose __bool__/__len__ evaluates falsy is still reported by class name.
        previous_name = previous.__class__.__name__ if previous is not None else "None"
        new_name = value.__class__.__name__ if value is not None else "None"
        log.debug(f"Model tokenizer changed from {previous_name} to {new_name}")

    self._tokenizer = value
74+
4775 @property
4876 @abstractmethod
4977 def label_type (self ) -> str :
@@ -104,19 +132,63 @@ def _get_state_dict(self) -> dict:
104132 - "optimizer_state_dict": The optimizer's state dictionary (if it exists)
105133 - "scheduler_state_dict": The scheduler's state dictionary (if it exists)
106134 - "model_card": Training parameters and metadata (if set)
135+ - "tokenizer_info": Information to reconstruct the tokenizer used during training (if any and serializable)
107136 """
108137 # Always include the name of the Model class for which the state dict holds
109- state_dict = {"state_dict" : self .state_dict (), "__cls__" : self .__class__ .__name__ }
138+ state = {"state_dict" : self .state_dict (), "__cls__" : self .__class__ .__name__ }
110139
111140 # Add optimizer state dict if it exists
112141 if hasattr (self , "optimizer_state_dict" ) and self .optimizer_state_dict is not None :
113- state_dict ["optimizer_state_dict" ] = self .optimizer_state_dict
142+ state ["optimizer_state_dict" ] = self .optimizer_state_dict
114143
115144 # Add scheduler state dict if it exists
116145 if hasattr (self , "scheduler_state_dict" ) and self .scheduler_state_dict is not None :
117- state_dict ["scheduler_state_dict" ] = self .scheduler_state_dict
146+ state ["scheduler_state_dict" ] = self .scheduler_state_dict
147+
148+ # -- Start Tokenizer Serialization Logic --
149+ tokenizer_info = None # Default: no tokenizer info saved
150+
151+ # Get the tokenizer
152+ current_tokenizer = self .tokenizer
153+
154+ if current_tokenizer is not None :
155+
156+ if hasattr (current_tokenizer , "to_dict" ) and callable (getattr (current_tokenizer , "to_dict" )):
157+ try :
158+ potential_tokenizer_info = current_tokenizer .to_dict ()
159+
160+ if (
161+ isinstance (potential_tokenizer_info , dict )
162+ and "class_module" in potential_tokenizer_info
163+ and "class_name" in potential_tokenizer_info
164+ ):
165+ tokenizer_info = potential_tokenizer_info # Store the valid dict
166+ else :
167+ log .warning (
168+ f"Tokenizer { current_tokenizer .__class__ .__name__ } has a 'to_dict' method, "
169+ f"but it did not return a valid dictionary with 'class_module' and "
170+ f"'class_name'. Tokenizer will not be saved automatically."
171+ )
172+ # tokenizer_info remains None
173+ except Exception as e :
174+ log .warning (
175+ f"Error calling 'to_dict' on tokenizer { current_tokenizer .__class__ .__name__ } : { e } . "
176+ f"Tokenizer will not be saved automatically."
177+ )
178+ # tokenizer_info remains None
179+ else :
180+ log .warning (
181+ f"Tokenizer { current_tokenizer .__class__ .__name__ } does not implement the 'to_dict' method "
182+ f"required for automatic saving. It will not be saved automatically. "
183+ f"You may need to manually attach it after loading the model."
184+ )
185+ # tokenizer_info remains None
118186
119- return state_dict
187+ # Add the determined tokenizer_info (either dict or None) to the state
188+ state ["tokenizer_info" ] = tokenizer_info # type: ignore[assignment]
189+ # -- End Tokenizer Serialization Logic --
190+
191+ return state
120192
121193 @classmethod
122194 def _init_model_with_state_dict (cls , state : dict [str , Any ], ** kwargs ):
@@ -128,7 +200,6 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
128200 kwargs ["embeddings" ] = embeddings
129201
130202 model = cls (** kwargs )
131-
132203 model .load_state_dict (state ["state_dict" ])
133204
134205 # load optimizer state if it exists in the state dict
@@ -141,7 +212,49 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
141212 log .debug (f"Found scheduler state in model file with keys: { state ['scheduler_state_dict' ].keys ()} " )
142213 model .scheduler_state_dict = state ["scheduler_state_dict" ]
143214
144- return model
215+ # --- Part 3: Load Tokenizer ---
216+ tokenizer_instance = None # Default to None
217+ if "tokenizer_info" in state and state ["tokenizer_info" ] is not None :
218+ tokenizer_info = state ["tokenizer_info" ]
219+ if isinstance (tokenizer_info , dict ) and "class_module" in tokenizer_info and "class_name" in tokenizer_info :
220+ module_name = tokenizer_info ["class_module" ]
221+ class_name = tokenizer_info ["class_name" ]
222+ try :
223+ # ... (importlib logic, call from_dict, error handling) ...
224+ module = importlib .import_module (module_name )
225+ TokenizerClass = getattr (module , class_name )
226+ if hasattr (TokenizerClass , "from_dict" ) and callable (getattr (TokenizerClass , "from_dict" )):
227+ tokenizer_instance = TokenizerClass .from_dict (tokenizer_info )
228+ log .info (f"Successfully loaded tokenizer '{ class_name } ' from '{ module_name } '." )
229+ else :
230+ log .warning (
231+ f"Tokenizer class '{ class_name } ' found in '{ module_name } ', but it is missing the required "
232+ f"'from_dict' class method. Tokenizer cannot be loaded automatically."
233+ )
234+ except ImportError :
235+ log .warning (
236+ f"Could not import tokenizer module '{ module_name } '. "
237+ f"Make sure the module containing '{ class_name } ' is installed and accessible."
238+ )
239+ except AttributeError :
240+ log .warning (f"Could not find tokenizer class '{ class_name } ' in module '{ module_name } '." )
241+ except Exception as e :
242+ log .warning (
243+ f"Error reconstructing tokenizer '{ class_name } ' from module '{ module_name } ' "
244+ f"using 'from_dict': { e } "
245+ )
246+
247+ else :
248+ log .warning (
249+ "Found 'tokenizer_info' in saved model state, but it is invalid (must be a dict with "
250+ "'class_module' and 'class_name'). Tokenizer cannot be loaded automatically."
251+ )
252+
253+ # Assign the result (instance or None) to the model
254+ model ._tokenizer = tokenizer_instance
255+ # --- End Tokenizer Loading ---
256+
257+ return model # Return the initialized model object
145258
146259 @staticmethod
147260 def _fetch_model (model_identifier : str ):
@@ -933,6 +1046,14 @@ def predict(
9331046 if isinstance (sentences [0 ], Sentence ):
9341047 Sentence .set_context_for_sentences (typing .cast (list [Sentence ], sentences ))
9351048
1049+ # Use the tokenizer property getter
1050+ model_tokenizer = self .tokenizer
1051+ if model_tokenizer is not None :
1052+ for sentence in sentences :
1053+ # this affects only models that call predict over Sentence or EncodedSentence objects (not Spans, etc.)
1054+ if isinstance (sentence , Sentence ):
1055+ sentence .tokenizer = model_tokenizer
1056+
9361057 reordered_sentences = self ._sort_data (sentences )
9371058
9381059 if len (reordered_sentences ) == 0 :
@@ -975,8 +1096,8 @@ def predict(
9751096 # if anything could possibly be predicted
9761097 if data_points :
9771098 # remove previously predicted labels of this type
978- for sentence in data_points :
979- sentence .remove_labels (label_name )
1099+ for data_point in data_points :
1100+ data_point .remove_labels (label_name )
9801101
9811102 if return_loss :
9821103 # filter data points that have labels outside of dictionary
0 commit comments