 import numpy as np
 import torch
 from lightning.pytorch.callbacks import Callback, EarlyStopping
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
 from sklearn.model_selection import train_test_split
 from tokenizers import Tokenizer
 from torch import nn
@@ -31,8 +32,8 @@ def __init__(
3132 """Initialize a standard classifier model."""
3233 self .n_layers = n_layers
3334 self .hidden_dim = hidden_dim
34- # Alias: Follows scikit-learn.
35- self .classes_ : list [str ] = []
35+ # Alias: Follows scikit-learn. Set to dummy classes
36+ self .classes_ : list [str ] = [str ( x ) for x in range ( out_dim ) ]
3637 super ().__init__ (vectors = vectors , out_dim = out_dim , pad_id = pad_id , tokenizer = tokenizer )
3738
3839 @property
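The dummy `classes_` keep the scikit-learn fitted-attribute convention intact even before `fit()` runs: `predict` resolves argmax indices through the classes, so pre-filling the list with stringified indices means an unfitted model still returns well-formed labels. A minimal sketch of that convention (the logits are fabricated for illustration, and `out_dim` is assumed to be 3):

```python
import torch

# Placeholder classes, initialized the same way as in this commit.
out_dim = 3
classes_ = [str(x) for x in range(out_dim)]  # ["0", "1", "2"]

# predict() maps argmax indices through classes_, so before fit() the model
# returns these placeholder strings instead of failing on an empty list.
logits = torch.randn(2, out_dim)
print([classes_[idx] for idx in logits.argmax(1)])  # e.g. ["2", "0"]
```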
@@ -45,57 +46,53 @@ def construct_head(self) -> nn.Module:
         if self.n_layers == 0:
             return nn.Linear(self.embed_dim, self.out_dim)
         modules = [
-            nn.Dropout(0.5),
             nn.Linear(self.embed_dim, self.hidden_dim),
-            nn.LayerNorm(self.hidden_dim),
             nn.ReLU(),
         ]
         for _ in range(self.n_layers - 1):
-            modules.extend(
-                [nn.Dropout(0.5), nn.Linear(self.hidden_dim, self.hidden_dim), nn.LayerNorm(self.hidden_dim), nn.ReLU()]
-            )
+            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
         modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

         for module in modules:
             if isinstance(module, nn.Linear):
-                nn.init.kaiming_normal_(module.weight)
+                nn.init.kaiming_uniform_(module.weight)
                 nn.init.zeros_(module.bias)

         return nn.Sequential(*modules)

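After this change the head is a plain Linear/ReLU stack with Kaiming-uniform weights and zero biases. A standalone sketch of what `construct_head` builds for `n_layers=1` (the dimensions are assumptions, chosen only for illustration):

```python
import torch
from torch import nn

# Hypothetical dimensions for the sketch.
embed_dim, hidden_dim, out_dim = 256, 128, 4

head = nn.Sequential(
    nn.Linear(embed_dim, hidden_dim),  # Dropout and LayerNorm are gone after this commit
    nn.ReLU(),
    nn.Linear(hidden_dim, out_dim),
)
for module in head:
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight)  # uniform variant replaces kaiming_normal_
        nn.init.zeros_(module.bias)

print(head(torch.randn(2, embed_dim)).shape)  # torch.Size([2, 4])
```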
-    def predict(self, texts: list[str]) -> list[str]:
+    def predict(self, X: list[str]) -> list[str]:
         """Predict a class for a set of texts."""
         pred: list[str] = []
-        for batch in range(0, len(texts), 1024):
-            logits = self._predict(texts[batch : batch + 1024])
+        for batch in range(0, len(X), 1024):
+            logits = self._predict(X[batch : batch + 1024])
             pred.extend([self.classes[idx] for idx in logits.argmax(1)])

         return pred

     @torch.no_grad()
-    def _predict(self, texts: list[str]) -> torch.Tensor:
-        input_ids = self.tokenize(texts)
+    def _predict(self, X: list[str]) -> torch.Tensor:
+        input_ids = self.tokenize(X)
         vectors, _ = self.forward(input_ids)
         return vectors

-    def predict_proba(self, texts: list[str]) -> np.ndarray:
+    def predict_proba(self, X: list[str]) -> np.ndarray:
         """Predict the probability of each class."""
         pred: list[np.ndarray] = []
-        for batch in range(0, len(texts), 1024):
-            logits = self._predict(texts[batch : batch + 1024])
+        for batch in range(0, len(X), 1024):
+            logits = self._predict(X[batch : batch + 1024])
             pred.append(torch.softmax(logits, dim=1).numpy())

         return np.concatenate(pred)

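Renaming `texts`/`labels` to `X`/`y` aligns the classifier with the scikit-learn estimator convention (`fit(X, y)`, `predict(X)`, `predict_proba(X)`). The fixed 1024-text batching inside `predict` can be sketched in isolation; here `_predict` is a stand-in that returns random logits, an assumption made purely for the demo:

```python
import torch

# Stand-in for self._predict: random logits for a batch (demo assumption).
def _predict(batch_texts: list[str]) -> torch.Tensor:
    return torch.randn(len(batch_texts), 3)

classes_ = ["neg", "neu", "pos"]
X = [f"text {i}" for i in range(3000)]

pred: list[str] = []
for batch in range(0, len(X), 1024):  # same fixed batch size as predict()
    logits = _predict(X[batch : batch + 1024])
    pred.extend(classes_[idx] for idx in logits.argmax(1))

assert len(pred) == len(X)  # 3000 predictions from three batches of <= 1024
```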
     def fit(
         self,
-        texts: list[str],
-        labels: list[str],
+        X: list[str],
+        y: list[str],
         **kwargs: Any,
     ) -> ClassificationStaticModel:
         """Fit a model."""
         pl.seed_everything(42)
-        classes = sorted(set(labels))
+        classes = sorted(set(y))
         self.classes_ = classes

         if len(self.classes) != self.out_dim:
@@ -105,15 +102,15 @@ def fit(
         self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)

         label_mapping = {label: idx for idx, label in enumerate(self.classes)}
-        label_counts = Counter(labels)
+        label_counts = Counter(y)
         if min(label_counts.values()) < 2:
             logger.info("Some classes have less than 2 samples. Stratification is disabled.")
             train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                texts, labels, test_size=0.1, random_state=42, shuffle=True
+                X, y, test_size=0.1, random_state=42, shuffle=True
             )
         else:
             train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                texts, labels, test_size=0.1, random_state=42, shuffle=True, stratify=labels
+                X, y, test_size=0.1, random_state=42, shuffle=True, stratify=y
             )

         # Turn labels into a LongTensor
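The `Counter` guard matters because scikit-learn's stratified split raises a `ValueError` when any class has a single member. A minimal demonstration of the fallback:

```python
from collections import Counter

from sklearn.model_selection import train_test_split

X = ["a", "b", "c", "d", "e"]
y = ["pos", "pos", "pos", "pos", "neg"]  # "neg" occurs only once

if min(Counter(y).values()) < 2:
    # stratify=y would raise "The least populated class in y has only 1 member";
    # fall back to a plain shuffled split instead.
    splits = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=True)
else:
    splits = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y)
```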
@@ -190,6 +187,18 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int

         return loss

-    def configure_optimizers(self) -> torch.optim.Optimizer:
+    def configure_optimizers(self) -> OptimizerLRScheduler:
         """Simple Adam optimizer."""
-        return torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer,
+            mode="min",
+            factor=0.5,
+            patience=3,
+            verbose=True,
+            min_lr=1e-6,
+            threshold=0.03,
+            threshold_mode="rel",
+        )
+
+        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
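With this dict return, Lightning steps the plateau scheduler at the end of each validation epoch, passing in the monitored metric (this assumes `validation_step` logs a `"val_loss"` metric via `self.log`). A standalone sketch of the scheduler's behavior with fabricated loss values:

```python
import torch
from torch import nn

params = [nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3, threshold=0.03, threshold_mode="rel", min_lr=1e-6
)

# Fabricated validation losses that stall after the first epoch: each epoch
# improves by less than the 3% relative threshold, so after `patience` bad
# epochs the learning rate is halved.
for epoch, val_loss in enumerate([1.0, 0.99, 0.985, 0.98, 0.979, 0.978]):
    scheduler.step(val_loss)  # Lightning calls this with the monitored "val_loss"
    print(epoch, optimizer.param_groups[0]["lr"])  # lr drops to 5e-4 at epoch 4
```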