44from collections import Counter
55from itertools import chain
66from tempfile import TemporaryDirectory
7- from typing import TypeVar , cast
7+ from typing import Generic , TypeVar , cast
88
99import lightning as pl
1010import numpy as np
2525logger = logging .getLogger (__name__ )
2626_RANDOM_SEED = 42
2727
28- LabelType = TypeVar ("LabelType" , list [str ], list [int ], list [list [str ]], list [list [int ]])
28+ PossibleLabels = list [str ] | list [list [str ]]
29+ LabelType = TypeVar ("LabelType" , list [str ], list [list [str ]])
2930
3031
31- class StaticModelForClassification (FinetunableStaticModel ):
32+ class StaticModelForClassification (FinetunableStaticModel , Generic [ LabelType ] ):
3233 def __init__ (
3334 self ,
3435 * ,
@@ -39,15 +40,23 @@ def __init__(
3940 out_dim : int = 2 ,
4041 pad_id : int = 0 ,
4142 token_mapping : list [int ] | None = None ,
43+ weights : torch .Tensor | None = None ,
4244 ) -> None :
4345 """Initialize a standard classifier model."""
4446 self .n_layers = n_layers
4547 self .hidden_dim = hidden_dim
4648 # Alias: Follows scikit-learn. Set to dummy classes
47- self .classes_ : list [str ] = [str ( x ) for x in range ( out_dim ) ]
49+ self .classes_ : list [str ] = ["0" , "1" ]
4850 # multilabel flag will be set based on the type of `y` passed to fit.
4951 self .multilabel : bool = False
50- super ().__init__ (vectors = vectors , out_dim = out_dim , pad_id = pad_id , tokenizer = tokenizer , token_mapping = token_mapping )
52+ super ().__init__ (
53+ vectors = vectors ,
54+ out_dim = out_dim ,
55+ pad_id = pad_id ,
56+ tokenizer = tokenizer ,
57+ token_mapping = token_mapping ,
58+ weights = weights ,
59+ )
5160
5261 @property
5362 def classes (self ) -> np .ndarray :
@@ -166,7 +175,7 @@ def fit(
166175 :param device: The device to train on. If this is "auto", the device is chosen automatically.
167176 :param X_val: The texts to be used for validation.
168177 :param y_val: The labels to be used for validation.
169- :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
178+ :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
170179 have the same length as the number of classes.
171180 :return: The fitted model.
172181 :raises ValueError: If either X_val or y_val are provided, but not both.
@@ -202,7 +211,7 @@ def fit(
202211 base_number = int (min (max (1 , (len (train_texts ) / 30 ) // 32 ), 16 ))
203212 batch_size = int (base_number * 32 )
204213 logger .info ("Batch size automatically set to %d." , batch_size )
205-
214+
206215 if class_weight is not None :
207216 if len (class_weight ) != len (self .classes_ ):
208217 raise ValueError ("class_weight must have the same length as the number of classes." )
@@ -284,11 +293,8 @@ def _initialize(self, y: LabelType) -> None:
284293
285294 :param y: The labels.
286295 :raises ValueError: If the labels are inconsistent.
287- """
288- if isinstance (y [0 ], (str , int )):
289- # Check if all labels are strings or integers.
290- if not all (isinstance (label , (str , int )) for label in y ):
291- raise ValueError ("Inconsistent label types in y. All labels must be strings or integers." )
296+ """
297+ if all (isinstance (label , str ) for label in y ):
292298 self .multilabel = False
293299 classes = sorted (set (y ))
294300 else :
@@ -330,13 +336,13 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) ->
330336 indices = [mapping [label ] for label in sample_labels ]
331337 labels_tensor [i , indices ] = 1.0
332338 else :
333- labels_tensor = torch .tensor ([self .classes_ .index (label ) for label in cast ( list [ str ], y ) ], dtype = torch .long )
339+ labels_tensor = torch .tensor ([self .classes_ .index (label ) for label in y ], dtype = torch .long )
334340 return TextDataset (tokenized , labels_tensor )
335341
336342 def _train_test_split (
337343 self ,
338344 X : list [str ],
339- y : list [ str ] | list [ list [ str ]] ,
345+ y : LabelType ,
340346 test_size : float ,
341347 ) -> tuple [list [str ], list [str ], LabelType , LabelType ]:
342348 """
@@ -384,12 +390,18 @@ def to_pipeline(self) -> StaticModelPipeline:
384390
385391
386392class _ClassifierLightningModule (pl .LightningModule ):
387- def __init__ (self , model : StaticModelForClassification , learning_rate : float , class_weight : torch .Tensor | None = None ) -> None :
393+ def __init__ (
394+ self , model : StaticModelForClassification , learning_rate : float , class_weight : torch .Tensor | None = None
395+ ) -> None :
388396 """Initialize the LightningModule."""
389397 super ().__init__ ()
390398 self .model = model
391399 self .learning_rate = learning_rate
392- self .loss_function = nn .CrossEntropyLoss (weight = class_weight ) if not model .multilabel else nn .BCEWithLogitsLoss (pos_weight = class_weight )
400+ self .loss_function = (
401+ nn .CrossEntropyLoss (weight = class_weight )
402+ if not model .multilabel
403+ else nn .BCEWithLogitsLoss (pos_weight = class_weight )
404+ )
393405
394406 def forward (self , x : torch .Tensor ) -> torch .Tensor :
395407 """Simple forward pass."""
@@ -408,10 +420,12 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i
408420 x , y = batch
409421 head_out , _ = self .model (x )
410422 loss = self .loss_function (head_out , y )
423+
424+ accuracy : float
411425 if self .model .multilabel :
412426 preds = (torch .sigmoid (head_out ) > 0.5 ).float ()
413427 # Multilabel accuracy is defined as the Jaccard score averaged over samples.
414- accuracy = jaccard_score (y .cpu (), preds .cpu (), average = "samples" )
428+ accuracy = cast ( float , jaccard_score (y .cpu (), preds .cpu (), average = "samples" ) )
415429 else :
416430 accuracy = (head_out .argmax (dim = 1 ) == y ).float ().mean ()
417431 self .log ("val_loss" , loss )
0 commit comments