Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
445 changes: 281 additions & 164 deletions flair/data.py

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Any, Literal, Optional, Union, cast
from typing import Any, Callable, Dict, Literal, Optional, Union, cast

import torch
import transformers
Expand All @@ -26,6 +26,7 @@
LayoutLMv2FeatureExtractor,
PretrainedConfig,
PreTrainedTokenizer,
T5Config,
T5TokenizerFast,
)
from transformers.tokenization_utils_base import LARGE_INTEGER
Expand Down Expand Up @@ -674,7 +675,9 @@ def __build_transformer_model_inputs(

if self.feature_extractor is not None:
images = [sent.get_metadata("image") for sent in sentences]
image_encodings = self.feature_extractor(images, return_tensors="pt")["pixel_values"]
# Cast self.feature_extractor to a callable type
feature_extractor_callable = cast(Callable[..., Dict[str, Any]], self.feature_extractor)
image_encodings = feature_extractor_callable(images, return_tensors="pt")["pixel_values"]
if cpu_overflow_to_sample_mapping is not None:
batched_image_encodings = [image_encodings[i] for i in cpu_overflow_to_sample_mapping]
image_encodings = torch.stack(batched_image_encodings)
Expand Down Expand Up @@ -1138,7 +1141,8 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
if is_supported_t5_model(saved_config):
from transformers import T5EncoderModel

transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs, **kwargs)
# Cast saved_config to T5Config
transformer_model = T5EncoderModel(cast(T5Config, saved_config), **transformers_model_kwargs, **kwargs)
else:
transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs, **kwargs)
try:
Expand Down
6 changes: 6 additions & 0 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ def predict(
# filter empty sentences
sentences = [sentence for sentence in sentences if len(sentence) > 0]

# Use the tokenizer property getter
model_tokenizer = self.tokenizer
if model_tokenizer is not None:
for sentence in sentences:
sentence.tokenizer = model_tokenizer

# reverse sort all sequences by their length
reordered_sentences = sorted(sentences, key=len, reverse=True)

Expand Down
14 changes: 9 additions & 5 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm

import flair
from flair.data import Corpus, Dictionary, Sentence, Span
from flair.data import Corpus, Dictionary, Sentence, Span, Token
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import (
TokenEmbeddings,
Expand Down Expand Up @@ -594,11 +594,15 @@ def predict(
continue

# only add if all tokens have no label
if tag_this:
if tag_this and isinstance(label.data_point, Span):
# get tokens and filter None (they don't exist, but we have to make this explicit for mypy)
token_list = [
sentence.get_token(token.idx - label_length) for token in label.data_point
]
token_list_no_none = [token for token in token_list if isinstance(token, Token)]

# make and add a corresponding predicted span
predicted_span = Span(
[sentence.get_token(token.idx - label_length) for token in label.data_point]
)
predicted_span = Span(token_list_no_none)
predicted_span.add_label(label_name, value=label.value, score=label.score)

# set indices so that no token can be tagged twice
Expand Down
137 changes: 129 additions & 8 deletions flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tqdm import tqdm

import flair
import flair.tokenization
from flair.class_utils import get_non_abstract_subclasses
from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
Expand All @@ -22,6 +23,7 @@
from flair.embeddings.base import load_embeddings
from flair.file_utils import Tqdm, load_torch_state
from flair.training_utils import EmbeddingStorageMode, Result, store_embeddings
import importlib

log = logging.getLogger("flair")

Expand All @@ -44,6 +46,32 @@ def __init__(self) -> None:
self.optimizer_state_dict: Optional[dict[str, Any]] = None
self.scheduler_state_dict: Optional[dict[str, Any]] = None

# Internal storage for the tokenizer
self._tokenizer: Optional[flair.tokenization.Tokenizer] = None

@property
def tokenizer(self) -> Optional[flair.tokenization.Tokenizer]:
    """Return the tokenizer attached to this model.

    Returns:
        Optional[flair.tokenization.Tokenizer]: The currently attached
        tokenizer instance, or None if no tokenizer has been set.
    """
    return self._tokenizer

@tokenizer.setter
def tokenizer(self, value: Optional[flair.tokenization.Tokenizer]) -> None:
    """Attach a tokenizer to this model.

    Args:
        value (Optional[flair.tokenization.Tokenizer]): The tokenizer
            instance to attach, or None to detach the current one.
    """
    # Only emit the debug message when the instance actually changes,
    # to avoid noisy logs on repeated assignment of the same tokenizer.
    if value is not self._tokenizer:
        old_name = self._tokenizer.__class__.__name__ if self._tokenizer else "None"
        new_name = value.__class__.__name__ if value else "None"
        log.debug(f"Model tokenizer changed from {old_name} to {new_name}")
    self._tokenizer = value

@property
@abstractmethod
def label_type(self) -> str:
Expand Down Expand Up @@ -104,19 +132,63 @@ def _get_state_dict(self) -> dict:
- "optimizer_state_dict": The optimizer's state dictionary (if it exists)
- "scheduler_state_dict": The scheduler's state dictionary (if it exists)
- "model_card": Training parameters and metadata (if set)
- "tokenizer_info": Information to reconstruct the tokenizer used during training (if any and serializable)
"""
# Always include the name of the Model class for which the state dict holds
state_dict = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__}
state = {"state_dict": self.state_dict(), "__cls__": self.__class__.__name__}

# Add optimizer state dict if it exists
if hasattr(self, "optimizer_state_dict") and self.optimizer_state_dict is not None:
state_dict["optimizer_state_dict"] = self.optimizer_state_dict
state["optimizer_state_dict"] = self.optimizer_state_dict

# Add scheduler state dict if it exists
if hasattr(self, "scheduler_state_dict") and self.scheduler_state_dict is not None:
state_dict["scheduler_state_dict"] = self.scheduler_state_dict
state["scheduler_state_dict"] = self.scheduler_state_dict

# -- Start Tokenizer Serialization Logic --
tokenizer_info = None # Default: no tokenizer info saved

# Get the tokenizer
current_tokenizer = self.tokenizer

if current_tokenizer is not None:

if hasattr(current_tokenizer, "to_dict") and callable(getattr(current_tokenizer, "to_dict")):
try:
potential_tokenizer_info = current_tokenizer.to_dict()

if (
isinstance(potential_tokenizer_info, dict)
and "class_module" in potential_tokenizer_info
and "class_name" in potential_tokenizer_info
):
tokenizer_info = potential_tokenizer_info # Store the valid dict
else:
log.warning(
f"Tokenizer {current_tokenizer.__class__.__name__} has a 'to_dict' method, "
f"but it did not return a valid dictionary with 'class_module' and "
f"'class_name'. Tokenizer will not be saved automatically."
)
# tokenizer_info remains None
except Exception as e:
log.warning(
f"Error calling 'to_dict' on tokenizer {current_tokenizer.__class__.__name__}: {e}. "
f"Tokenizer will not be saved automatically."
)
# tokenizer_info remains None
else:
log.warning(
f"Tokenizer {current_tokenizer.__class__.__name__} does not implement the 'to_dict' method "
f"required for automatic saving. It will not be saved automatically. "
f"You may need to manually attach it after loading the model."
)
# tokenizer_info remains None

return state_dict
# Add the determined tokenizer_info (either dict or None) to the state
state["tokenizer_info"] = tokenizer_info # type: ignore[assignment]
# -- End Tokenizer Serialization Logic --

return state

@classmethod
def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
Expand All @@ -128,7 +200,6 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
kwargs["embeddings"] = embeddings

model = cls(**kwargs)

model.load_state_dict(state["state_dict"])

# load optimizer state if it exists in the state dict
Expand All @@ -141,7 +212,49 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
log.debug(f"Found scheduler state in model file with keys: {state['scheduler_state_dict'].keys()}")
model.scheduler_state_dict = state["scheduler_state_dict"]

return model
# --- Part 3: Load Tokenizer ---
tokenizer_instance = None # Default to None
if "tokenizer_info" in state and state["tokenizer_info"] is not None:
tokenizer_info = state["tokenizer_info"]
if isinstance(tokenizer_info, dict) and "class_module" in tokenizer_info and "class_name" in tokenizer_info:
module_name = tokenizer_info["class_module"]
class_name = tokenizer_info["class_name"]
try:
# ... (importlib logic, call from_dict, error handling) ...
module = importlib.import_module(module_name)
TokenizerClass = getattr(module, class_name)
if hasattr(TokenizerClass, "from_dict") and callable(getattr(TokenizerClass, "from_dict")):
tokenizer_instance = TokenizerClass.from_dict(tokenizer_info)
log.info(f"Successfully loaded tokenizer '{class_name}' from '{module_name}'.")
else:
log.warning(
f"Tokenizer class '{class_name}' found in '{module_name}', but it is missing the required "
f"'from_dict' class method. Tokenizer cannot be loaded automatically."
)
except ImportError:
log.warning(
f"Could not import tokenizer module '{module_name}'. "
f"Make sure the module containing '{class_name}' is installed and accessible."
)
except AttributeError:
log.warning(f"Could not find tokenizer class '{class_name}' in module '{module_name}'.")
except Exception as e:
log.warning(
f"Error reconstructing tokenizer '{class_name}' from module '{module_name}' "
f"using 'from_dict': {e}"
)

else:
log.warning(
"Found 'tokenizer_info' in saved model state, but it is invalid (must be a dict with "
"'class_module' and 'class_name'). Tokenizer cannot be loaded automatically."
)

# Assign the result (instance or None) to the model
model._tokenizer = tokenizer_instance
# --- End Tokenizer Loading ---

return model # Return the initialized model object

@staticmethod
def _fetch_model(model_identifier: str):
Expand Down Expand Up @@ -933,6 +1046,14 @@ def predict(
if isinstance(sentences[0], Sentence):
Sentence.set_context_for_sentences(typing.cast(list[Sentence], sentences))

# Use the tokenizer property getter
model_tokenizer = self.tokenizer
if model_tokenizer is not None:
for sentence in sentences:
# this affects only models that call predict over Sentence or EncodedSentence objects (not Spans, etc.)
if isinstance(sentence, Sentence):
sentence.tokenizer = model_tokenizer

reordered_sentences = self._sort_data(sentences)

if len(reordered_sentences) == 0:
Expand Down Expand Up @@ -975,8 +1096,8 @@ def predict(
# if anything could possibly be predicted
if data_points:
# remove previously predicted labels of this type
for sentence in data_points:
sentence.remove_labels(label_name)
for data_point in data_points:
data_point.remove_labels(label_name)

if return_loss:
# filter data points that have labels outside of dictionary
Expand Down
Loading
Loading