11import logging
2+ import pathlib
23import random
34from collections import defaultdict
45from enum import Enum
56from functools import reduce
67from math import inf
78from pathlib import Path
8- from typing import Literal , Optional , Union
9+ from typing import Literal , NamedTuple , Optional , Union
910
11+ from numpy import ndarray
1012from scipy .stats import pearsonr , spearmanr
1113from sklearn .metrics import mean_absolute_error , mean_squared_error
1214from torch .optim import Optimizer
1315from torch .utils .data import Dataset
1416
1517import flair
16- from flair .data import DT , Dictionary , Sentence , _iter_dataset
18+ from flair .class_utils import StringLike
19+ from flair .data import DT , Dictionary , Sentence , Token , _iter_dataset
1720
1821EmbeddingStorageMode = Literal ["none" , "cpu" , "gpu" ]
19- log = logging .getLogger ("flair" )
22+ MinMax = Literal ["min" , "max" ]
23+ logger = logging .getLogger ("flair" )
2024
2125
2226class Result :
@@ -33,7 +37,7 @@ def __init__(
3337 self .main_score : float = main_score
3438 self .scores = scores
3539 self .detailed_results : str = detailed_results
36- self .classification_report = classification_report
40+ self .classification_report = classification_report if classification_report is not None else {}
3741
3842 @property
3943 def loss (self ):
@@ -44,13 +48,13 @@ def __str__(self) -> str:
4448
4549
4650class MetricRegression :
47- def __init__ (self , name ) -> None :
51+ def __init__ (self , name : str ) -> None :
4852 self .name = name
4953
5054 self .true : list [float ] = []
5155 self .pred : list [float ] = []
5256
53- def mean_squared_error (self ):
57+ def mean_squared_error (self ) -> Union [ float , ndarray ] :
5458 return mean_squared_error (self .true , self .pred )
5559
5660 def mean_absolute_error (self ):
@@ -62,22 +66,18 @@ def pearsonr(self):
6266 def spearmanr (self ):
6367 return spearmanr (self .true , self .pred )[0 ]
6468
65- # dummy return to fulfill trainer.train() needs
66- def micro_avg_f_score (self ):
67- return self .mean_squared_error ()
68-
69- def to_tsv (self ):
69+ def to_tsv (self ) -> str :
7070 return f"{ self .mean_squared_error ()} \t { self .mean_absolute_error ()} \t { self .pearsonr ()} \t { self .spearmanr ()} "
7171
7272 @staticmethod
73- def tsv_header (prefix = None ):
73+ def tsv_header (prefix : StringLike = None ) -> str :
7474 if prefix :
7575 return f"{ prefix } _MEAN_SQUARED_ERROR\t { prefix } _MEAN_ABSOLUTE_ERROR\t { prefix } _PEARSON\t { prefix } _SPEARMAN"
7676
7777 return "MEAN_SQUARED_ERROR\t MEAN_ABSOLUTE_ERROR\t PEARSON\t SPEARMAN"
7878
7979 @staticmethod
80- def to_empty_tsv ():
80+ def to_empty_tsv () -> str :
8181 return "\t _\t _\t _\t _"
8282
8383 def __str__ (self ) -> str :
@@ -101,13 +101,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
101101 self .weights_dict : dict [str , dict [int , list [float ]]] = defaultdict (lambda : defaultdict (list ))
102102 self .number_of_weights = number_of_weights
103103
104- def extract_weights (self , state_dict , iteration ) :
104+ def extract_weights (self , state_dict : dict , iteration : int ) -> None :
105105 for key in state_dict :
106106 vec = state_dict [key ]
107- # print(vec)
108107 try :
109108 weights_to_watch = min (self .number_of_weights , reduce (lambda x , y : x * y , list (vec .size ())))
110- except Exception :
109+ except Exception as e :
110+ logger .debug (e )
111111 continue
112112
113113 if key not in self .weights_dict :
@@ -195,15 +195,15 @@ class AnnealOnPlateau:
195195 def __init__ (
196196 self ,
197197 optimizer ,
198- mode = "min" ,
199- aux_mode = "min" ,
200- factor = 0.1 ,
201- patience = 10 ,
202- initial_extra_patience = 0 ,
203- verbose = False ,
204- cooldown = 0 ,
205- min_lr = 0 ,
206- eps = 1e-8 ,
198+ mode : MinMax = "min" ,
199+ aux_mode : MinMax = "min" ,
200+ factor : float = 0.1 ,
201+ patience : int = 10 ,
202+ initial_extra_patience : int = 0 ,
203+ verbose : bool = False ,
204+ cooldown : int = 0 ,
205+ min_lr : float = 0. 0 ,
206+ eps : float = 1e-8 ,
207207 ) -> None :
208208 if factor >= 1.0 :
209209 raise ValueError ("Factor should be < 1.0." )
@@ -214,6 +214,7 @@ def __init__(
214214 raise TypeError (f"{ type (optimizer ).__name__ } is not an Optimizer" )
215215 self .optimizer = optimizer
216216
217+ self .min_lrs : list [float ]
217218 if isinstance (min_lr , (list , tuple )):
218219 if len (min_lr ) != len (optimizer .param_groups ):
219220 raise ValueError (f"expected { len (optimizer .param_groups )} min_lrs, got { len (min_lr )} " )
@@ -231,7 +232,7 @@ def __init__(
231232 self .best = None
232233 self .best_aux = None
233234 self .num_bad_epochs = None
234- self .mode_worse = None # the worse value for the chosen mode
235+ self .mode_worse : Optional [ float ] = None # the worse value for the chosen mode
235236 self .eps = eps
236237 self .last_epoch = 0
237238 self ._init_is_better (mode = mode )
@@ -258,7 +259,7 @@ def step(self, metric, auxiliary_metric=None) -> bool:
258259 if self .mode == "max" and current > self .best :
259260 is_better = True
260261
261- if current == self .best and auxiliary_metric :
262+ if current == self .best and auxiliary_metric is not None :
262263 current_aux = float (auxiliary_metric )
263264 if self .aux_mode == "min" and current_aux < self .best_aux :
264265 is_better = True
@@ -289,20 +290,20 @@ def step(self, metric, auxiliary_metric=None) -> bool:
289290
290291 return reduce_learning_rate
291292
292- def _reduce_lr (self , epoch ) :
293+ def _reduce_lr (self , epoch : int ) -> None :
293294 for i , param_group in enumerate (self .optimizer .param_groups ):
294295 old_lr = float (param_group ["lr" ])
295296 new_lr = max (old_lr * self .factor , self .min_lrs [i ])
296297 if old_lr - new_lr > self .eps :
297298 param_group ["lr" ] = new_lr
298299 if self .verbose :
299- log .info (f" - reducing learning rate of group { epoch } to { new_lr } " )
300+ logger .info (f" - reducing learning rate of group { epoch } to { new_lr } " )
300301
    @property
    def in_cooldown(self):
        """Whether the cooldown counter is still positive (anneal checks are paused)."""
        return self.cooldown_counter > 0
304305
305- def _init_is_better (self , mode ) :
306+ def _init_is_better (self , mode : MinMax ) -> None :
306307 if mode not in {"min" , "max" }:
307308 raise ValueError ("mode " + mode + " is unknown!" )
308309
@@ -313,10 +314,10 @@ def _init_is_better(self, mode):
313314
314315 self .mode = mode
315316
316- def state_dict (self ):
317+ def state_dict (self ) -> dict :
317318 return {key : value for key , value in self .__dict__ .items () if key != "optimizer" }
318319
    def load_state_dict(self, state_dict: dict) -> None:
        """Restores scheduler state previously produced by state_dict().

        Re-runs _init_is_better afterwards so mode-derived fields stay consistent
        with the restored ``mode``.
        """
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode)
322323
@@ -350,11 +351,11 @@ def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionar
350351 return [[1 if label in labels else 0 for label in label_dict .get_items ()] for labels in label_list ]
351352
352353
def log_line(log: logging.Logger) -> None:
    """Emit a horizontal separator (100 dashes) at INFO level on the given logger.

    stacklevel=3 attributes the record to the original caller rather than this helper.
    """
    separator = "-" * 100
    log.info(separator, stacklevel=3)
355356
356357
357- def add_file_handler (log , output_file ) :
358+ def add_file_handler (log : logging . Logger , output_file : pathlib . Path ) -> logging . FileHandler :
358359 init_output_file (output_file .parents [0 ], output_file .name )
359360 fh = logging .FileHandler (output_file , mode = "w" , encoding = "utf-8" )
360361 fh .setLevel (logging .INFO )
@@ -368,11 +369,19 @@ def store_embeddings(
368369 data_points : Union [list [DT ], Dataset ],
369370 storage_mode : EmbeddingStorageMode ,
370371 dynamic_embeddings : Optional [list [str ]] = None ,
371- ):
372+ ) -> None :
373+ """Stores embeddings of data points in memory or on disk.
374+
375+ Args:
376+ data_points: a DataSet or list of DataPoints for which embeddings should be stored
377+ storage_mode: store in either CPU or GPU memory, or delete them if set to 'none'
378+ dynamic_embeddings: these are always deleted. If not passed, they are identified automatically.
379+ """
380+
372381 if isinstance (data_points , Dataset ):
373382 data_points = list (_iter_dataset (data_points ))
374383
375- # if memory mode option 'none' delete everything
384+ # if storage mode option 'none' delete everything
376385 if storage_mode == "none" :
377386 dynamic_embeddings = None
378387
@@ -411,3 +420,97 @@ def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
411420 if not all_embeddings :
412421 return None
413422 return list (set (dynamic_embeddings ))
423+
424+
class TokenEntity(NamedTuple):
    """Entity represented by token indices.

    Indices follow Python slice semantics relative to the sentence's token
    list: start is inclusive, end is exclusive.
    """

    start_token_idx: int  # index of the entity's first token (inclusive)
    end_token_idx: int  # index one past the entity's last token (exclusive)
    label: str  # entity class label
    value: str = ""  # text value of the entity
    score: float = 1.0  # confidence score attached to the label
433+
434+
class CharEntity(NamedTuple):
    """Entity represented by character indices.

    Offsets are into the full text, with slice semantics: start inclusive,
    end exclusive.
    """

    start_char_idx: int  # index of the entity's first character (inclusive)
    end_char_idx: int  # index one past the entity's last character (exclusive)
    label: str  # entity class label
    value: str  # surface text of the entity
    score: float = 1.0  # confidence score attached to the label
443+
444+
def create_labeled_sentence_from_tokens(
    tokens: list[Token], token_entities: list[TokenEntity], type_name: str = "ner"
) -> Sentence:
    """Creates a new Sentence object from a list of tokens and applies entity labels.

    Tokens are recreated with the same text, but not attached to the previous sentence.
    (Note: plain strings are not accepted — ``token.text`` is read from each element.)

    Args:
        tokens: a list of Token objects - only the text is used, not any labels
        token_entities: a list of TokenEntity objects representing entity annotations
        type_name: the type of entity label to apply
    Returns:
        A labeled Sentence object
    """
    # use only the token texts so the new tokens do not belong to any prior sentence
    token_texts = [token.text for token in tokens]
    sentence = Sentence(token_texts, use_tokenizer=True)
    # token spans use slice semantics: start inclusive, end exclusive
    for entity in token_entities:
        sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
    return sentence
464+
465+
def create_labeled_sentence_from_entity_offsets(
    text: str,
    entities: list[CharEntity],
    token_limit: float = inf,
) -> Sentence:
    """Creates a labeled sentence from a text and a list of entity annotations.

    The function explicitly tokenizes the text and labels separately, ensuring entity labels are
    not partially split across tokens. The sentence is truncated if a token limit is set.

    Args:
        text: The full text to be tokenized and labeled.
        entities: Ordered non-overlapping entity annotations given as character
            offsets into ``text`` (start inclusive, end exclusive).
        token_limit: numerical value that determines the maximum token length of
            the sentence; use ``inf`` to not perform chunking.

    Returns:
        A labeled Sentence object representing the text and entity annotations.
    """
    tokens: list[Token] = []
    token_entities: list[TokenEntity] = []
    current_index = 0

    for entity in entities:
        # tokenize any plain text that precedes this entity
        if current_index < entity.start_char_idx:
            sentence = Sentence(text[current_index : entity.start_char_idx])
            tokens.extend(sentence)

        # tokenize the entity span and record its position in token coordinates
        start_token_idx = len(tokens)
        entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
        end_token_idx = start_token_idx + len(entity_sentence)

        token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
        token_entities.append(token_entity)
        tokens.extend(entity_sentence)

        current_index = entity.end_char_idx

    # tokenize any trailing text after the last entity
    if current_index < len(text):
        remaining_sentence = Sentence(text[current_index:])
        tokens.extend(remaining_sentence)

    # truncate to the token limit, dropping entities that do not fit completely.
    # bug fix: previously only int limits were honored (isinstance check), so a
    # finite float limit silently disabled chunking; inf still means "no limit".
    if token_limit < len(tokens):
        limit = int(token_limit)
        tokens = tokens[:limit]
        token_entities = [entity for entity in token_entities if entity.end_token_idx <= limit]

    return create_labeled_sentence_from_tokens(tokens, token_entities)
0 commit comments