Skip to content

Commit 4867cb8

Browse files
authored
feat: Add support for passing weight to the classification loss functions (#260)
* add support for passing weight to the loss functions * adds test for weights and fixes issue with state dict
1 parent 06a478c commit 4867cb8

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

model2vec/train/classifier.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def fit(
137137
device: str = "auto",
138138
X_val: list[str] | None = None,
139139
y_val: LabelType | None = None,
140+
class_weight: torch.Tensor | None = None,
140141
) -> StaticModelForClassification:
141142
"""
142143
Fit a model.
@@ -164,6 +165,8 @@ def fit(
164165
:param device: The device to train on. If this is "auto", the device is chosen automatically.
165166
:param X_val: The texts to be used for validation.
166167
:param y_val: The labels to be used for validation.
168+
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
169+
have the same length as the number of classes.
167170
:return: The fitted model.
168171
:raises ValueError: If either X_val or y_val are provided, but not both.
169172
"""
@@ -198,13 +201,17 @@ def fit(
198201
base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
199202
batch_size = int(base_number * 32)
200203
logger.info("Batch size automatically set to %d.", batch_size)
204+
205+
if class_weight is not None:
206+
if len(class_weight) != len(self.classes_):
207+
raise ValueError("class_weight must have the same length as the number of classes.")
201208

202209
logger.info("Preparing train dataset.")
203210
train_dataset = self._prepare_dataset(train_texts, train_labels)
204211
logger.info("Preparing validation dataset.")
205212
val_dataset = self._prepare_dataset(validation_texts, validation_labels)
206213

207-
c = _ClassifierLightningModule(self, learning_rate=learning_rate)
214+
c = _ClassifierLightningModule(self, learning_rate=learning_rate, class_weight=class_weight)
208215

209216
n_train_batches = len(train_dataset) // batch_size
210217
callbacks: list[Callback] = []
@@ -242,6 +249,9 @@ def fit(
242249

243250
state_dict = {}
244251
for weight_name, weight in best_model_weights["state_dict"].items():
252+
if "loss_function" in weight_name:
253+
# Skip the loss function class weight as its not needed for predictions
254+
continue
245255
state_dict[weight_name.removeprefix("model.")] = weight
246256

247257
self.load_state_dict(state_dict)
@@ -373,12 +383,12 @@ def to_pipeline(self) -> StaticModelPipeline:
373383

374384

375385
class _ClassifierLightningModule(pl.LightningModule):
376-
def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None:
386+
def __init__(self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
377387
"""Initialize the LightningModule."""
378388
super().__init__()
379389
self.model = model
380390
self.learning_rate = learning_rate
381-
self.loss_function = nn.CrossEntropyLoss() if not model.multilabel else nn.BCEWithLogitsLoss()
391+
self.loss_function = nn.CrossEntropyLoss(weight=class_weight) if not model.multilabel else nn.BCEWithLogitsLoss(pos_weight=class_weight)
382392

383393
def forward(self, x: torch.Tensor) -> torch.Tensor:
384394
"""Simple forward pass."""

tests/test_trainable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,23 @@ def test_y_val_none() -> None:
174174
model.fit(X, y, X_val=None, y_val=y_val)
175175
model.fit(X, y, X_val=None, y_val=None)
176176

177+
def test_class_weight() -> None:
178+
"""Test the class weight function."""
179+
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
180+
torch.random.manual_seed(42)
181+
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
182+
model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
183+
184+
X = ["dog", "cat"]
185+
y = ["0", "1"]
186+
187+
bad_class_weight = torch.tensor([1.0])
188+
with pytest.raises(ValueError):
189+
model.fit(X, y, class_weight=bad_class_weight)
190+
191+
class_weight = torch.tensor([1.0, 2.0])
192+
model.fit(X, y, class_weight=class_weight)
193+
177194

178195
@pytest.mark.parametrize(
179196
"y_multi,y_val_multi,should_crash",

0 commit comments

Comments
 (0)