Commit 8df3aaf

Fix bugs
1 parent e27f9dc commit 8df3aaf

2 files changed: 36 additions, 27 deletions

model2vec/train/base.py

Lines changed: 3 additions & 3 deletions
@@ -25,12 +25,11 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int,
         self.pad_id = pad_id
         self.out_dim = out_dim
         self.embed_dim = vectors.shape[1]
+        self.vectors = vectors

-        self.vectors = torch.randn_like(vectors)
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
         self.head = self.construct_head()

-        # Weights for
         weights = torch.zeros(len(vectors))
         weights[pad_id] = -10_000
         self.w = nn.Parameter(weights)
@@ -71,7 +70,7 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         :return: The mean over the input ids, weighted by token weights.
         """
         w = self.w[input_ids]
-        w = torch.softmax(w, dim=1)
+        w = torch.sigmoid(w)
         zeros = (input_ids != self.pad_id).float()
         w = w * zeros
         # Add a small epsilon to avoid division by zero
@@ -80,6 +79,7 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         # Simulate actual mean
         # Zero out the padding
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
+        # embedded = embedded.sum(1)
         embedded = embedded / length[:, None]

         return nn.functional.normalize(embedded)
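The softmax-to-sigmoid switch turns the token weights from a distribution over each sequence into independent per-token gates: with softmax the weights already summed to 1, so the later division by length double-normalized, whereas a sigmoid gate followed by that division gives something closer to an actual weighted mean (as the "# Simulate actual mean" comment suggests). Below is a minimal, self-contained sketch of the pooling the updated `_encode` performs; the shapes, the definition of `length`, and the epsilon value are assumptions, since those surrounding lines are not part of this diff.

    import torch

    def weighted_mean_pool(input_ids: torch.Tensor,
                           embeddings: torch.nn.Embedding,
                           token_weights: torch.Tensor,
                           pad_id: int) -> torch.Tensor:
        # Per-token gate in (0, 1); the padding logit is initialised to
        # -10_000 in __init__, so its gate is effectively zero.
        w = torch.sigmoid(token_weights[input_ids])
        mask = (input_ids != pad_id).float()
        w = w * mask
        # Assumed: length counts non-padding tokens; epsilon avoids division by zero.
        length = mask.sum(1) + 1e-16
        embedded = embeddings(input_ids)                         # (batch, seq, dim)
        pooled = torch.bmm(w[:, None, :], embedded).squeeze(1)   # weighted sum over tokens
        pooled = pooled / length[:, None]                        # approximate weighted mean
        return torch.nn.functional.normalize(pooled)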

model2vec/train/classifier.py

Lines changed: 33 additions & 24 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 from lightning.pytorch.callbacks import Callback, EarlyStopping
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
 from sklearn.model_selection import train_test_split
 from tokenizers import Tokenizer
 from torch import nn
@@ -31,8 +32,8 @@ def __init__(
         """Initialize a standard classifier model."""
         self.n_layers = n_layers
         self.hidden_dim = hidden_dim
-        # Alias: Follows scikit-learn.
-        self.classes_: list[str] = []
+        # Alias: Follows scikit-learn. Set to dummy classes
+        self.classes_: list[str] = [str(x) for x in range(out_dim)]
         super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)

     @property
@@ -45,57 +46,53 @@ def construct_head(self) -> nn.Module:
         if self.n_layers == 0:
             return nn.Linear(self.embed_dim, self.out_dim)
         modules = [
-            nn.Dropout(0.5),
             nn.Linear(self.embed_dim, self.hidden_dim),
-            nn.LayerNorm(self.hidden_dim),
             nn.ReLU(),
         ]
         for _ in range(self.n_layers - 1):
-            modules.extend(
-                [nn.Dropout(0.5), nn.Linear(self.hidden_dim, self.hidden_dim), nn.LayerNorm(self.hidden_dim), nn.ReLU()]
-            )
+            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
         modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

         for module in modules:
             if isinstance(module, nn.Linear):
-                nn.init.kaiming_normal_(module.weight)
+                nn.init.kaiming_uniform_(module.weight)
                 nn.init.zeros_(module.bias)

         return nn.Sequential(*modules)

-    def predict(self, texts: list[str]) -> list[str]:
+    def predict(self, X: list[str]) -> list[str]:
         """Predict a class for a set of texts."""
         pred: list[str] = []
-        for batch in range(0, len(texts), 1024):
-            logits = self._predict(texts[batch : batch + 1024])
+        for batch in range(0, len(X), 1024):
+            logits = self._predict(X[batch : batch + 1024])
             pred.extend([self.classes[idx] for idx in logits.argmax(1)])

         return pred

     @torch.no_grad()
-    def _predict(self, texts: list[str]) -> torch.Tensor:
-        input_ids = self.tokenize(texts)
+    def _predict(self, X: list[str]) -> torch.Tensor:
+        input_ids = self.tokenize(X)
         vectors, _ = self.forward(input_ids)
         return vectors

-    def predict_proba(self, texts: list[str]) -> np.ndarray:
+    def predict_proba(self, X: list[str]) -> np.ndarray:
         """Predict the probability of each class."""
         pred: list[np.ndarray] = []
-        for batch in range(0, len(texts), 1024):
-            logits = self._predict(texts[batch : batch + 1024])
+        for batch in range(0, len(X), 1024):
+            logits = self._predict(X[batch : batch + 1024])
             pred.append(torch.softmax(logits, dim=1).numpy())

         return np.concatenate(pred)

     def fit(
         self,
-        texts: list[str],
-        labels: list[str],
+        X: list[str],
+        y: list[str],
         **kwargs: Any,
     ) -> ClassificationStaticModel:
         """Fit a model."""
         pl.seed_everything(42)
-        classes = sorted(set(labels))
+        classes = sorted(set(y))
         self.classes_ = classes

         if len(self.classes) != self.out_dim:
@@ -105,15 +102,15 @@ def fit(
         self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)

         label_mapping = {label: idx for idx, label in enumerate(self.classes)}
-        label_counts = Counter(labels)
+        label_counts = Counter(y)
         if min(label_counts.values()) < 2:
             logger.info("Some classes have less than 2 samples. Stratification is disabled.")
             train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                texts, labels, test_size=0.1, random_state=42, shuffle=True
+                X, y, test_size=0.1, random_state=42, shuffle=True
             )
         else:
             train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                texts, labels, test_size=0.1, random_state=42, shuffle=True, stratify=labels
+                X, y, test_size=0.1, random_state=42, shuffle=True, stratify=y
             )

         # Turn labels into a LongTensor
@@ -190,6 +187,18 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i

         return loss

-    def configure_optimizers(self) -> torch.optim.Optimizer:
+    def configure_optimizers(self) -> OptimizerLRScheduler:
         """Simple Adam optimizer."""
-        return torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer,
+            mode="min",
+            factor=0.5,
+            patience=3,
+            verbose=True,
+            min_lr=1e-6,
+            threshold=0.03,
+            threshold_mode="rel",
+        )
+
+        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
