@@ -88,79 +88,96 @@ def fit(
         self,
         X: list[str],
         y: list[str],
-        **kwargs: Any,
+        learning_rate: float = 1e-3,
+        batch_size: int = 32,
+        early_stopping_patience: int | None = 25,
+        test_size: float = 0.1,
     ) -> ClassificationStaticModel:
         """Fit a model."""
         pl.seed_everything(42)
-        classes = sorted(set(y))
-        self.classes_ = classes
-
-        if len(self.classes) != self.out_dim:
-            self.out_dim = len(self.classes)
-
-        self.head = self.construct_head()
-        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
-
-        label_mapping = {label: idx for idx, label in enumerate(self.classes)}
-        label_counts = Counter(y)
-        if min(label_counts.values()) < 2:
-            logger.info("Some classes have less than 2 samples. Stratification is disabled.")
-            train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                X, y, test_size=0.1, random_state=42, shuffle=True
-            )
-        else:
-            train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                X, y, test_size=0.1, random_state=42, shuffle=True, stratify=y
-            )
+        self._initialize(y)
 
-        # Turn labels into a LongTensor
-        train_tokenized: list[list[int]] = [
-            encoding.ids for encoding in self.tokenizer.encode_batch_fast(train_texts, add_special_tokens=False)
-        ]
-        train_labels_tensor = torch.Tensor([label_mapping[label] for label in train_labels]).long()
-        train_dataset = TextDataset(train_tokenized, train_labels_tensor)
+        train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
+            X, y, test_size=test_size
+        )
 
-        val_tokenized: list[list[int]] = [
-            encoding.ids for encoding in self.tokenizer.encode_batch_fast(validation_texts, add_special_tokens=False)
-        ]
-        val_labels_tensor = torch.Tensor([label_mapping[label] for label in validation_labels]).long()
-        val_dataset = TextDataset(val_tokenized, val_labels_tensor)
+        train_dataset = self._prepare_dataset(train_texts, train_labels)
+        val_dataset = self._prepare_dataset(validation_texts, validation_labels)
 
-        c = ClassifierLightningModule(self)
+        c = ClassifierLightningModule(self, learning_rate=learning_rate)
 
-        batch_size = 32
         n_train_batches = len(train_dataset) // batch_size
-        callbacks: list[Callback] = [EarlyStopping(monitor="val_accuracy", mode="max", patience=5)]
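+        # Early stopping is optional: passing early_stopping_patience=None leaves the callback list empty.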
+        callbacks: list[Callback] = []
+        if early_stopping_patience is not None:
+            callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience)
+            callbacks.append(callback)
+
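+        # On small datasets, validate once per epoch; on large ones, validate every val_check_interval batches.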
         if n_train_batches < 250:
-            trainer = pl.Trainer(max_epochs=500, callbacks=callbacks, check_val_every_n_epoch=1)
+            val_check_interval = None
+            check_val_every_epoch = 1
         else:
             val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
-            trainer = pl.Trainer(
-                max_epochs=500, callbacks=callbacks, val_check_interval=val_check_interval, check_val_every_n_epoch=None
-            )
+            check_val_every_epoch = None
+        trainer = pl.Trainer(
+            max_epochs=500,
+            callbacks=callbacks,
+            val_check_interval=val_check_interval,
+            check_val_every_n_epoch=check_val_every_epoch,
+        )
 
         trainer.fit(
             c,
             train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
             val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
         )
         best_model_path = trainer.checkpoint_callback.best_model_path  # type: ignore
+        best_model_weights = torch.load(best_model_path, weights_only=True)
 
-        state_dict = {
-            k.removeprefix("model."): v for k, v in torch.load(best_model_path, weights_only=True)["state_dict"].items()
-        }
-        self.load_state_dict(state_dict)
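+        # The Lightning module wraps this model under the "model." attribute, so strip that prefix from checkpoint keys.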
+        state_dict = {}
+        for weight_name, weight in best_model_weights["state_dict"].items():
+            state_dict[weight_name.removeprefix("model.")] = weight
 
+        self.load_state_dict(state_dict)
         self.eval()
 
         return self
 
+    def _initialize(self, y: list[str]) -> None:
+        """Set the output dimensionality and the classes, and initialize the head."""
+        classes = sorted(set(y))
+        self.classes_ = classes
+
+        if len(self.classes) != self.out_dim:
+            self.out_dim = len(self.classes)
+
+        self.head = self.construct_head()
+        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
+
+    def _prepare_dataset(self, X: list[str], y: list[str]) -> TextDataset:
+        """Prepare a dataset."""
+        tokenized: list[list[int]] = [
+            encoding.ids for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False)
+        ]
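+        # Encode each label as its index into the sorted class list.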
+        labels_tensor = torch.Tensor([self.classes.index(label) for label in y]).long()
+        return TextDataset(tokenized, labels_tensor)
+
+    def _train_test_split(
+        self, X: list[str], y: list[str], test_size: float
+    ) -> tuple[list[str], list[str], list[str], list[str]]:
+        """Split the data."""
+        label_counts = Counter(y)
+        if min(label_counts.values()) < 2:
+            logger.info("Some classes have less than 2 samples. Stratification is disabled.")
+            return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True)
+        return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y)
+
 
 class ClassifierLightningModule(pl.LightningModule):
-    def __init__(self, model: ClassificationStaticModel) -> None:
+    def __init__(self, model: ClassificationStaticModel, learning_rate: float) -> None:
         """Initialize the lightningmodule."""
         super().__init__()
         self.model = model
+        self.learning_rate = learning_rate
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Simple forward pass."""
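For reference, a minimal sketch of how the reworked `fit` signature is called after this change. The construction of the `ClassificationStaticModel` instance is assumed here, since it is not part of this diff; the keyword defaults shown are the ones introduced above.

```python
# Hypothetical usage sketch: `model` is assumed to be an already-constructed
# ClassificationStaticModel; its constructor is not shown in this diff.
X = ["great product", "terrible service", "loved it", "would not recommend"]
y = ["pos", "neg", "pos", "neg"]

model = model.fit(
    X,
    y,
    learning_rate=1e-3,          # optimizer step size (default shown)
    batch_size=32,               # dataloader batch size (default shown)
    early_stopping_patience=25,  # pass None to disable early stopping
    test_size=0.1,               # fraction held out for validation
)
```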