
Commit b608081

wip
1 parent 75bdf4f commit b608081

6 files changed: +69 -35 lines changed


model2vec/inference/model.py

Lines changed: 15 additions & 9 deletions
@@ -3,7 +3,7 @@
 import re
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Sequence, TypeVar
+from typing import Sequence, TypeVar, cast

 import huggingface_hub
 import numpy as np
@@ -273,14 +273,14 @@ def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None:
     )


-def _is_multi_label_shaped(y: LabelType) -> bool:
+def _is_multi_label_shaped(y: list[int] | list[str] | list[list[int]] | list[list[str]]) -> bool:
     """Check if the labels are in a multi-label shape."""
     return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set))


 def evaluate_single_or_multi_label(
     predictions: np.ndarray,
-    y: LabelType,
+    y: list[int] | list[str] | list[list[int]] | list[list[str]],
     output_dict: bool = False,
 ) -> str | dict[str, dict[str, float]]:
     """
@@ -292,16 +292,22 @@ def evaluate_single_or_multi_label(
     :return: A classification report.
     """
     if _is_multi_label_shaped(y):
+        # Cast because the type checker doesn't understand that y is a list of lists.
+        y = cast(list[list[str]] | list[list[int]], y)
         classes = sorted(set([label for labels in y for label in labels]))
         mlb = MultiLabelBinarizer(classes=classes)
-        y = mlb.fit_transform(y)
-        predictions = mlb.transform(predictions)
-    elif isinstance(y[0], (str, int)):
-        classes = sorted(set(y))
+        y_transformed = mlb.fit_transform(y)
+        predictions_transformed = mlb.transform(predictions)
+    else:
+        if all(isinstance(label, (str, int)) for label in y):
+            y = cast(list[str] | list[int], y)
+            classes = sorted(set(y))
+        y_transformed = np.array(y)
+        predictions_transformed = np.array(predictions)

     report = classification_report(
-        y,
-        predictions,
+        y_transformed,
+        predictions_transformed,
         output_dict=output_dict,
         zero_division=0,
     )
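Note: the refactor keeps the MultiLabelBinarizer path for multi-label targets and converts the single-label path to plain arrays before calling classification_report. A minimal standalone sketch of the multi-label branch, with made-up labels (only the MultiLabelBinarizer / classification_report usage mirrors the diff):

# Sketch of the multi-label evaluation pattern; the labels here are toy data.
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer

y_true = [["a", "b"], ["b"], ["a", "c"]]
y_pred = [["a", "b"], ["b"], ["a"]]

classes = sorted({label for labels in y_true for label in labels})
mlb = MultiLabelBinarizer(classes=classes)

# Fit on the gold labels, then transform both sides so they become
# aligned binary indicator matrices of shape (n_samples, n_classes).
y_bin = mlb.fit_transform(y_true)
pred_bin = mlb.transform(y_pred)

print(classification_report(y_bin, pred_bin, zero_division=0))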

model2vec/model.py

Lines changed: 2 additions & 2 deletions
@@ -452,7 +452,7 @@ def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
             id_list_remapped = [self.token_mapping.get(token_id, token_id) for token_id in id_list]
             emb = self.embedding[id_list_remapped]
             if self.weights is not None:
-                emb = (emb * self.weights[id_list][:, None])
+                emb = emb * self.weights[id_list][:, None]
             emb = emb.mean(axis=0)

             out.append(emb)
@@ -514,4 +514,4 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:

        embeddings, tokenizer, config = load_local_model(path)

-       return StaticModel(embeddings, tokenizer, config)
+       return StaticModel(embeddings, tokenizer, config=config)
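Note: the first hunk only drops redundant parentheses, but it sits in the token-weighting step of _encode_batch: each token vector is scaled by its weight, then mean-pooled. A small numpy sketch of that pooling, with toy values standing in for the real embedding table and weights:

# Sketch of weighted mean pooling as in _encode_batch; toy data only.
import numpy as np

embedding = np.random.rand(10, 4)   # 10-token vocabulary, 4-dim vectors
weights = np.random.rand(10)        # one scalar weight per token
id_list = [3, 7, 7, 1]              # token ids of one tokenized sentence

emb = embedding[id_list]                 # (4, 4): one row per token
emb = emb * weights[id_list][:, None]    # scale each token vector by its weight
sentence_vector = emb.mean(axis=0)       # (4,): pooled sentence representation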

model2vec/train/base.py

Lines changed: 17 additions & 4 deletions
@@ -16,7 +16,16 @@


 class FinetunableStaticModel(nn.Module):
-    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0, token_mapping: list[int] | None = None) -> None:
+    def __init__(
+        self,
+        *,
+        vectors: torch.Tensor,
+        tokenizer: Tokenizer,
+        out_dim: int = 2,
+        pad_id: int = 0,
+        token_mapping: list[int] | None = None,
+        weights: torch.Tensor | None = None,
+    ) -> None:
         """
         Initialize a trainable StaticModel from a StaticModel.

@@ -45,7 +54,7 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int
         self.token_mapping = nn.Parameter(self.token_mapping, requires_grad=False)
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
         self.head = self.construct_head()
-        self.w = self.construct_weights()
+        self.w = self.construct_weights() if weights is None else nn.Parameter(weights, requires_grad=True)
         self.tokenizer = tokenizer

     def construct_weights(self) -> nn.Parameter:
@@ -70,6 +79,7 @@ def from_pretrained(
     def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType:
         """Load the model from a static model."""
         model.embedding = np.nan_to_num(model.embedding)
+        weights = torch.from_numpy(model.weights) if model.weights is not None else None
         embeddings_converted = torch.from_numpy(model.embedding)
         if model.token_mapping is not None:
             token_mapping = [i for _, i in sorted(model.token_mapping.items(), key=lambda x: x[0])]
@@ -81,6 +91,7 @@ def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int
             out_dim=out_dim,
             tokenizer=model.tokenizer,
             token_mapping=token_mapping,
+            weights=weights,
             **kwargs,
         )

@@ -139,7 +150,9 @@ def to_static_model(self) -> StaticModel:
         w = torch.sigmoid(self.w).detach().cpu().numpy()
         token_mapping = {i: int(token_id) for i, token_id in enumerate(self.token_mapping.tolist())}

-        return StaticModel(vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=True, token_mapping=token_mapping)
+        return StaticModel(
+            vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=True, token_mapping=token_mapping
+        )


 class TextDataset(Dataset):
@@ -169,7 +182,7 @@ def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate function."""
        texts, targets = zip(*batch)

-       tensors = [torch.LongTensor(x) for x in texts]
+       tensors: list[torch.Tensor] = [torch.LongTensor(x) for x in texts]
        padded = pad_sequence(tensors, batch_first=True, padding_value=0)

        return padded, torch.stack(targets)
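Note: besides threading the optional weights tensor through __init__ and from_static_model, the diff touches collate_fn, which pads variable-length token id lists into one batch tensor with pad_sequence. A short sketch of that collation outside the class (the batch contents are illustrative):

# Sketch of the pad_sequence-based collation used by collate_fn; toy batch.
import torch
from torch.nn.utils.rnn import pad_sequence

batch = [([1, 2, 3], 0), ([4, 5], 1), ([6], 0)]   # (token ids, target) pairs
texts, targets = zip(*batch)

tensors: list[torch.Tensor] = [torch.LongTensor(x) for x in texts]
padded = pad_sequence(tensors, batch_first=True, padding_value=0)  # shape (3, 3)
targets_tensor = torch.tensor(targets, dtype=torch.long)

print(padded)
print(targets_tensor)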

model2vec/train/classifier.py

Lines changed: 31 additions & 17 deletions
@@ -4,7 +4,7 @@
 from collections import Counter
 from itertools import chain
 from tempfile import TemporaryDirectory
-from typing import TypeVar, cast
+from typing import Generic, TypeVar, cast

 import lightning as pl
 import numpy as np
@@ -25,10 +25,11 @@
 logger = logging.getLogger(__name__)
 _RANDOM_SEED = 42

-LabelType = TypeVar("LabelType", list[str], list[int], list[list[str]], list[list[int]])
+PossibleLabels = list[str] | list[list[str]]
+LabelType = TypeVar("LabelType", list[str], list[list[str]])


-class StaticModelForClassification(FinetunableStaticModel):
+class StaticModelForClassification(FinetunableStaticModel, Generic[LabelType]):
     def __init__(
         self,
         *,
@@ -39,15 +40,23 @@ def __init__(
         out_dim: int = 2,
         pad_id: int = 0,
         token_mapping: list[int] | None = None,
+        weights: torch.Tensor | None = None,
     ) -> None:
         """Initialize a standard classifier model."""
         self.n_layers = n_layers
         self.hidden_dim = hidden_dim
         # Alias: Follows scikit-learn. Set to dummy classes
-        self.classes_: list[str] = [str(x) for x in range(out_dim)]
+        self.classes_: list[str] = ["0", "1"]
         # multilabel flag will be set based on the type of `y` passed to fit.
         self.multilabel: bool = False
-        super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer, token_mapping=token_mapping)
+        super().__init__(
+            vectors=vectors,
+            out_dim=out_dim,
+            pad_id=pad_id,
+            tokenizer=tokenizer,
+            token_mapping=token_mapping,
+            weights=weights,
+        )

     @property
     def classes(self) -> np.ndarray:
@@ -166,7 +175,7 @@ def fit(
        :param device: The device to train on. If this is "auto", the device is chosen automatically.
        :param X_val: The texts to be used for validation.
        :param y_val: The labels to be used for validation.
-       :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
+       :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
            have the same length as the number of classes.
        :return: The fitted model.
        :raises ValueError: If either X_val or y_val are provided, but not both.
@@ -202,7 +211,7 @@ def fit(
             base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
             batch_size = int(base_number * 32)
             logger.info("Batch size automatically set to %d.", batch_size)
-
+
         if class_weight is not None:
             if len(class_weight) != len(self.classes_):
                 raise ValueError("class_weight must have the same length as the number of classes.")
@@ -284,11 +293,8 @@ def _initialize(self, y: LabelType) -> None:

        :param y: The labels.
        :raises ValueError: If the labels are inconsistent.
-       """
-       if isinstance(y[0], (str, int)):
-           # Check if all labels are strings or integers.
-           if not all(isinstance(label, (str, int)) for label in y):
-               raise ValueError("Inconsistent label types in y. All labels must be strings or integers.")
+       """
+       if all(isinstance(label, str) for label in y):
            self.multilabel = False
            classes = sorted(set(y))
        else:
@@ -330,13 +336,13 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) ->
                indices = [mapping[label] for label in sample_labels]
                labels_tensor[i, indices] = 1.0
        else:
-           labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long)
+           labels_tensor = torch.tensor([self.classes_.index(label) for label in y], dtype=torch.long)
        return TextDataset(tokenized, labels_tensor)

     def _train_test_split(
         self,
         X: list[str],
-        y: list[str] | list[list[str]],
+        y: LabelType,
         test_size: float,
     ) -> tuple[list[str], list[str], LabelType, LabelType]:
         """
@@ -384,12 +390,18 @@ def to_pipeline(self) -> StaticModelPipeline:


 class _ClassifierLightningModule(pl.LightningModule):
-    def __init__(self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
+    def __init__(
+        self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None
+    ) -> None:
         """Initialize the LightningModule."""
         super().__init__()
         self.model = model
         self.learning_rate = learning_rate
-        self.loss_function = nn.CrossEntropyLoss(weight=class_weight) if not model.multilabel else nn.BCEWithLogitsLoss(pos_weight=class_weight)
+        self.loss_function = (
+            nn.CrossEntropyLoss(weight=class_weight)
+            if not model.multilabel
+            else nn.BCEWithLogitsLoss(pos_weight=class_weight)
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Simple forward pass."""
@@ -408,10 +420,12 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
         x, y = batch
         head_out, _ = self.model(x)
         loss = self.loss_function(head_out, y)
+
+        accuracy: float
         if self.model.multilabel:
             preds = (torch.sigmoid(head_out) > 0.5).float()
             # Multilabel accuracy is defined as the Jaccard score averaged over samples.
-            accuracy = jaccard_score(y.cpu(), preds.cpu(), average="samples")
+            accuracy = cast(float, jaccard_score(y.cpu(), preds.cpu(), average="samples"))
         else:
             accuracy = (head_out.argmax(dim=1) == y).float().mean()
         self.log("val_loss", loss)
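Note: the reformatted LightningModule keeps the same behaviour: BCEWithLogitsLoss plus a sample-averaged Jaccard score when the model is multilabel, CrossEntropyLoss plus argmax accuracy otherwise. A standalone sketch of that branching with dummy logits and targets (nothing below comes from the real model):

# Sketch of the multilabel loss/metric split in _ClassifierLightningModule; dummy data.
import torch
from torch import nn
from sklearn.metrics import jaccard_score

multilabel = True
class_weight = None

loss_function = (
    nn.CrossEntropyLoss(weight=class_weight)
    if not multilabel
    else nn.BCEWithLogitsLoss(pos_weight=class_weight)
)

head_out = torch.randn(4, 3)  # logits for 4 samples, 3 labels
y = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

loss = loss_function(head_out, y)
preds = (torch.sigmoid(head_out) > 0.5).float()
# Multilabel accuracy as the Jaccard score averaged over samples.
accuracy = float(jaccard_score(y.numpy(), preds.numpy(), average="samples"))
print(loss.item(), accuracy)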

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def mock_vectors() -> np.ndarray:
 @pytest.fixture
 def mock_config() -> dict[str, Any]:
     """Create a mock config."""
-    return {"some_config": "value", "token_mapping": [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]}
+    return {"some_config": "value"}


 @pytest.fixture(scope="session")

tests/test_trainable.py

Lines changed: 3 additions & 2 deletions
@@ -91,8 +91,8 @@ def test_conversion(mock_trained_pipeline: StaticModelForClassification) -> None:
     """Test the conversion to numpy."""
     staticmodel = mock_trained_pipeline.to_static_model()
     with torch.no_grad():
-        result_1 = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long()).numpy()
-        result_2 = staticmodel.embedding[[[0, 1], [1, 0]]].mean(0)
+        result_1 = mock_trained_pipeline._encode(torch.tensor([[1, 2], [2, 1]]).long()).numpy()
+        result_2 = staticmodel.embedding[[[1, 2], [2, 1]]].mean(0)
     result_2 /= np.linalg.norm(result_2, axis=1, keepdims=True)

     assert np.allclose(result_1, result_2)
@@ -174,6 +174,7 @@ def test_y_val_none() -> None:
     model.fit(X, y, X_val=None, y_val=y_val)
     model.fit(X, y, X_val=None, y_val=None)

+
 def test_class_weight() -> None:
     """Test the class weight function."""
     tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
