Skip to content

Commit ae592bf

Browse files
authored
Merge pull request #3570 from MattGPT-ai/add-sentence-labeler
Add sentence labeler
2 parents f97264a + 082e845 commit ae592bf

File tree

6 files changed

+438
-40
lines changed

6 files changed

+438
-40
lines changed

flair/class_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22
import inspect
33
from collections.abc import Iterable
44
from types import ModuleType
5-
from typing import Any, Optional, TypeVar, Union, overload
5+
from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload
6+
67

78
T = TypeVar("T")
89

910

10-
def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]:
11+
class StringLike(Protocol):
    """Structural type for any object that can be rendered as a string via ``str()``."""

    def __str__(self) -> str:
        ...
13+
14+
15+
def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
1116
for subclass in cls.__subclasses__():
1217
yield from get_non_abstract_subclasses(subclass)
1318
if inspect.isabstract(subclass):

flair/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ def __init__(
565565
head_id: Optional[int] = None,
566566
whitespace_after: int = 1,
567567
start_position: int = 0,
568-
sentence=None,
568+
sentence: Optional["Sentence"] = None,
569569
) -> None:
570570
super().__init__(sentence=sentence)
571571

flair/training_utils.py

Lines changed: 139 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
import logging
2+
import pathlib
23
import random
34
from collections import defaultdict
45
from enum import Enum
56
from functools import reduce
67
from math import inf
78
from pathlib import Path
8-
from typing import Literal, Optional, Union
9+
from typing import Literal, NamedTuple, Optional, Union
910

11+
from numpy import ndarray
1012
from scipy.stats import pearsonr, spearmanr
1113
from sklearn.metrics import mean_absolute_error, mean_squared_error
1214
from torch.optim import Optimizer
1315
from torch.utils.data import Dataset
1416

1517
import flair
16-
from flair.data import DT, Dictionary, Sentence, _iter_dataset
18+
from flair.class_utils import StringLike
19+
from flair.data import DT, Dictionary, Sentence, Token, _iter_dataset
1720

1821
EmbeddingStorageMode = Literal["none", "cpu", "gpu"]
19-
log = logging.getLogger("flair")
22+
MinMax = Literal["min", "max"]
23+
logger = logging.getLogger("flair")
2024

2125

2226
class Result:
@@ -33,7 +37,7 @@ def __init__(
3337
self.main_score: float = main_score
3438
self.scores = scores
3539
self.detailed_results: str = detailed_results
36-
self.classification_report = classification_report
40+
self.classification_report = classification_report if classification_report is not None else {}
3741

3842
@property
3943
def loss(self):
@@ -44,13 +48,13 @@ def __str__(self) -> str:
4448

4549

4650
class MetricRegression:
47-
def __init__(self, name) -> None:
51+
def __init__(self, name: str) -> None:
4852
self.name = name
4953

5054
self.true: list[float] = []
5155
self.pred: list[float] = []
5256

53-
def mean_squared_error(self):
57+
def mean_squared_error(self) -> Union[float, ndarray]:
5458
return mean_squared_error(self.true, self.pred)
5559

5660
def mean_absolute_error(self):
@@ -62,22 +66,18 @@ def pearsonr(self):
6266
def spearmanr(self):
6367
return spearmanr(self.true, self.pred)[0]
6468

65-
# dummy return to fulfill trainer.train() needs
66-
def micro_avg_f_score(self):
67-
return self.mean_squared_error()
68-
69-
def to_tsv(self):
69+
def to_tsv(self) -> str:
7070
return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}"
7171

7272
@staticmethod
73-
def tsv_header(prefix=None):
73+
def tsv_header(prefix: StringLike = None) -> str:
7474
if prefix:
7575
return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN"
7676

7777
return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"
7878

7979
@staticmethod
80-
def to_empty_tsv():
80+
def to_empty_tsv() -> str:
8181
return "\t_\t_\t_\t_"
8282

8383
def __str__(self) -> str:
@@ -101,13 +101,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
101101
self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
102102
self.number_of_weights = number_of_weights
103103

104-
def extract_weights(self, state_dict, iteration):
104+
def extract_weights(self, state_dict: dict, iteration: int) -> None:
105105
for key in state_dict:
106106
vec = state_dict[key]
107-
# print(vec)
108107
try:
109108
weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size())))
110-
except Exception:
109+
except Exception as e:
110+
logger.debug(e)
111111
continue
112112

113113
if key not in self.weights_dict:
@@ -195,15 +195,15 @@ class AnnealOnPlateau:
195195
def __init__(
196196
self,
197197
optimizer,
198-
mode="min",
199-
aux_mode="min",
200-
factor=0.1,
201-
patience=10,
202-
initial_extra_patience=0,
203-
verbose=False,
204-
cooldown=0,
205-
min_lr=0,
206-
eps=1e-8,
198+
mode: MinMax = "min",
199+
aux_mode: MinMax = "min",
200+
factor: float = 0.1,
201+
patience: int = 10,
202+
initial_extra_patience: int = 0,
203+
verbose: bool = False,
204+
cooldown: int = 0,
205+
min_lr: float = 0.0,
206+
eps: float = 1e-8,
207207
) -> None:
208208
if factor >= 1.0:
209209
raise ValueError("Factor should be < 1.0.")
@@ -214,6 +214,7 @@ def __init__(
214214
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
215215
self.optimizer = optimizer
216216

217+
self.min_lrs: list[float]
217218
if isinstance(min_lr, (list, tuple)):
218219
if len(min_lr) != len(optimizer.param_groups):
219220
raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
@@ -231,7 +232,7 @@ def __init__(
231232
self.best = None
232233
self.best_aux = None
233234
self.num_bad_epochs = None
234-
self.mode_worse = None # the worse value for the chosen mode
235+
self.mode_worse: Optional[float] = None # the worse value for the chosen mode
235236
self.eps = eps
236237
self.last_epoch = 0
237238
self._init_is_better(mode=mode)
@@ -258,7 +259,7 @@ def step(self, metric, auxiliary_metric=None) -> bool:
258259
if self.mode == "max" and current > self.best:
259260
is_better = True
260261

261-
if current == self.best and auxiliary_metric:
262+
if current == self.best and auxiliary_metric is not None:
262263
current_aux = float(auxiliary_metric)
263264
if self.aux_mode == "min" and current_aux < self.best_aux:
264265
is_better = True
@@ -289,20 +290,20 @@ def step(self, metric, auxiliary_metric=None) -> bool:
289290

290291
return reduce_learning_rate
291292

292-
def _reduce_lr(self, epoch):
293+
def _reduce_lr(self, epoch: int) -> None:
293294
for i, param_group in enumerate(self.optimizer.param_groups):
294295
old_lr = float(param_group["lr"])
295296
new_lr = max(old_lr * self.factor, self.min_lrs[i])
296297
if old_lr - new_lr > self.eps:
297298
param_group["lr"] = new_lr
298299
if self.verbose:
299-
log.info(f" - reducing learning rate of group {epoch} to {new_lr}")
300+
logger.info(f" - reducing learning rate of group {epoch} to {new_lr}")
300301

301302
@property
302303
def in_cooldown(self):
303304
return self.cooldown_counter > 0
304305

305-
def _init_is_better(self, mode):
306+
def _init_is_better(self, mode: MinMax) -> None:
306307
if mode not in {"min", "max"}:
307308
raise ValueError("mode " + mode + " is unknown!")
308309

@@ -313,10 +314,10 @@ def _init_is_better(self, mode):
313314

314315
self.mode = mode
315316

316-
def state_dict(self):
317+
def state_dict(self) -> dict:
317318
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
318319

319-
def load_state_dict(self, state_dict):
320+
def load_state_dict(self, state_dict: dict) -> None:
320321
self.__dict__.update(state_dict)
321322
self._init_is_better(mode=self.mode)
322323

@@ -350,11 +351,11 @@ def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionar
350351
return [[1 if label in labels else 0 for label in label_dict.get_items()] for labels in label_list]
351352

352353

353-
def log_line(log):
354+
def log_line(log: logging.Logger) -> None:
354355
log.info("-" * 100, stacklevel=3)
355356

356357

357-
def add_file_handler(log, output_file):
358+
def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.FileHandler:
358359
init_output_file(output_file.parents[0], output_file.name)
359360
fh = logging.FileHandler(output_file, mode="w", encoding="utf-8")
360361
fh.setLevel(logging.INFO)
@@ -368,11 +369,19 @@ def store_embeddings(
368369
data_points: Union[list[DT], Dataset],
369370
storage_mode: EmbeddingStorageMode,
370371
dynamic_embeddings: Optional[list[str]] = None,
371-
):
372+
) -> None:
373+
"""Stores embeddings of data points in memory or on disk.
374+
375+
Args:
376+
data_points: a DataSet or list of DataPoints for which embeddings should be stored
377+
storage_mode: store in either CPU or GPU memory, or delete them if set to 'none'
378+
dynamic_embeddings: these are always deleted. If not passed, they are identified automatically.
379+
"""
380+
372381
if isinstance(data_points, Dataset):
373382
data_points = list(_iter_dataset(data_points))
374383

375-
# if memory mode option 'none' delete everything
384+
# if storage mode option 'none' delete everything
376385
if storage_mode == "none":
377386
dynamic_embeddings = None
378387

@@ -411,3 +420,97 @@ def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
411420
if not all_embeddings:
412421
return None
413422
return list(set(dynamic_embeddings))
423+
424+
425+
class TokenEntity(NamedTuple):
    """Entity annotation expressed as a half-open range of token indices.

    Attributes:
        start_token_idx: index of the first token covered by the entity
        end_token_idx: index one past the last covered token
        label: the entity class assigned to the span
        value: surface text of the entity, if known
        score: confidence of the annotation
    """

    start_token_idx: int
    end_token_idx: int
    label: str
    value: str = ""  # text value of the entity
    score: float = 1.0
433+
434+
435+
class CharEntity(NamedTuple):
    """Entity annotation expressed as a half-open range of character offsets.

    Attributes:
        start_char_idx: offset of the first character of the entity
        end_char_idx: offset one past the last character
        label: the entity class assigned to the span
        value: surface text of the entity
        score: confidence of the annotation
    """

    start_char_idx: int
    end_char_idx: int
    label: str
    value: str
    score: float = 1.0
443+
444+
445+
def create_labeled_sentence_from_tokens(
    tokens: Union[list[Token], list[str]], token_entities: list[TokenEntity], type_name: str = "ner"
) -> Sentence:
    """Creates a new Sentence object from a list of tokens or strings and applies entity labels.

    Tokens are recreated with the same text, but not attached to the previous sentence.

    Args:
        tokens: a list of Token objects or strings - only the text is used, not any labels
        token_entities: a list of TokenEntity objects representing entity annotations
        type_name: the type of entity label to apply

    Returns:
        A labeled Sentence object
    """
    # Fix: the original annotation was the one-member `Union[list[Token]]`, and the
    # body crashed on plain strings despite the docstring promising to accept them.
    # Rebuild tokens from text alone so they are detached from any source sentence.
    texts = [token.text if isinstance(token, Token) else str(token) for token in tokens]
    sentence = Sentence(texts, use_tokenizer=True)
    for entity in token_entities:
        # Labels span the half-open token range [start_token_idx, end_token_idx).
        sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
    return sentence
464+
465+
466+
def create_labeled_sentence_from_entity_offsets(
    text: str,
    entities: list[CharEntity],
    token_limit: float = inf,
) -> Sentence:
    """Creates a labeled sentence from a text and a list of entity annotations.

    The function explicitly tokenizes the text and labels separately, ensuring entity labels are
    not partially split across tokens. The sentence is truncated if a token limit is set.

    Args:
        text (str): The full text to be tokenized and labeled.
        entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
            format (start_char_index, end_char_index, entity_class, entity_text).
        token_limit: numerical value that determines the maximum token length of the sentence.
            use inf to not perform chunking

    Returns:
        A labeled Sentence object representing the text and entity annotations.
    """
    tokens: list[Token] = []
    current_index = 0
    token_entities: list[TokenEntity] = []

    for entity in entities:
        if current_index < entity.start_char_idx:
            # tokenize and append the plain text between the previous position and this entity
            sentence = Sentence(text[current_index : entity.start_char_idx])
            tokens.extend(sentence)

        # tokenize the entity span and record its token-index range
        start_token_idx = len(tokens)
        entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
        end_token_idx = start_token_idx + len(entity_sentence)

        token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
        token_entities.append(token_entity)
        tokens.extend(entity_sentence)

        current_index = entity.end_char_idx

    # add any remaining tokens after the last entity
    if current_index < len(text):
        remaining_sentence = Sentence(text[current_index:])
        tokens.extend(remaining_sentence)

    # Bug fix: the original guard `isinstance(token_limit, int)` silently skipped
    # truncation whenever a float limit (e.g. 100.0) was passed, even though the
    # parameter is typed `float`. Compare numerically instead; `inf < len(tokens)`
    # is always False, so the unlimited default is unaffected.
    if token_limit < len(tokens):
        limit = int(token_limit)
        tokens = tokens[:limit]
        # drop entities that no longer fit entirely inside the truncated sentence
        token_entities = [entity for entity in token_entities if entity.end_token_idx <= limit]

    return create_labeled_sentence_from_tokens(tokens, token_entities)

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ types-Deprecated>=1.2.9.2
1212
types-requests>=2.28.11.17
1313
types-tabulate>=0.9.0.2
1414
pyab3p
15-
transformers!=4.40.1,!=4.40.0
15+
transformers!=4.40.1,!=4.40.0

0 commit comments

Comments
 (0)