@@ -137,6 +137,7 @@ def fit(
137137 device : str = "auto" ,
138138 X_val : list [str ] | None = None ,
139139 y_val : LabelType | None = None ,
140+ class_weight : torch .Tensor | None = None ,
140141 ) -> StaticModelForClassification :
141142 """
142143 Fit a model.
@@ -164,6 +165,8 @@ def fit(
164165 :param device: The device to train on. If this is "auto", the device is chosen automatically.
165166 :param X_val: The texts to be used for validation.
166167 :param y_val: The labels to be used for validation.
168+ :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
169+ have the same length as the number of classes.
167170 :return: The fitted model.
168171 :raises ValueError: If either X_val or y_val are provided, but not both.
169172 """
@@ -198,13 +201,17 @@ def fit(
198201 base_number = int (min (max (1 , (len (train_texts ) / 30 ) // 32 ), 16 ))
199202 batch_size = int (base_number * 32 )
200203 logger .info ("Batch size automatically set to %d." , batch_size )
204+
205+ if class_weight is not None :
206+ if len (class_weight ) != len (self .classes_ ):
207+ raise ValueError ("class_weight must have the same length as the number of classes." )
201208
202209 logger .info ("Preparing train dataset." )
203210 train_dataset = self ._prepare_dataset (train_texts , train_labels )
204211 logger .info ("Preparing validation dataset." )
205212 val_dataset = self ._prepare_dataset (validation_texts , validation_labels )
206213
207- c = _ClassifierLightningModule (self , learning_rate = learning_rate )
214+ c = _ClassifierLightningModule (self , learning_rate = learning_rate , class_weight = class_weight )
208215
209216 n_train_batches = len (train_dataset ) // batch_size
210217 callbacks : list [Callback ] = []
@@ -242,6 +249,9 @@ def fit(
242249
243250 state_dict = {}
244251 for weight_name , weight in best_model_weights ["state_dict" ].items ():
252+ if "loss_function" in weight_name :
253+ # Skip the loss function's class weight, as it's not needed for predictions
254+ continue
245255 state_dict [weight_name .removeprefix ("model." )] = weight
246256
247257 self .load_state_dict (state_dict )
@@ -373,12 +383,12 @@ def to_pipeline(self) -> StaticModelPipeline:
373383
374384
375385class _ClassifierLightningModule (pl .LightningModule ):
376- def __init__ (self , model : StaticModelForClassification , learning_rate : float ) -> None :
386+ def __init__ (self , model : StaticModelForClassification , learning_rate : float , class_weight : torch . Tensor | None = None ) -> None :
377387 """Initialize the LightningModule."""
378388 super ().__init__ ()
379389 self .model = model
380390 self .learning_rate = learning_rate
381- self .loss_function = nn .CrossEntropyLoss () if not model .multilabel else nn .BCEWithLogitsLoss ()
391+ self .loss_function = nn .CrossEntropyLoss (weight = class_weight ) if not model .multilabel else nn .BCEWithLogitsLoss (pos_weight = class_weight )
382392
383393 def forward (self , x : torch .Tensor ) -> torch .Tensor :
384394 """Simple forward pass."""
0 commit comments