Commit 9fe7e33 ("fixes")
Parent: 1a8569a
File tree: 2 files changed, +19 -6

- model2vec/train/base.py
- model2vec/train/classifier.py
model2vec/train/base.py (17 additions, 5 deletions)
```diff
@@ -16,7 +16,7 @@
 
 
 class FinetunableStaticModel(nn.Module):
-    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
+    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0, token_mapping: list[int] | None = None) -> None:
         """
         Initialize a trainable StaticModel from a StaticModel.
@@ -38,14 +38,19 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int
         )
         self.vectors = vectors.float()
 
+        if token_mapping is not None:
+            self.token_mapping = torch.tensor(token_mapping, dtype=torch.int64)
+        else:
+            self.token_mapping = torch.arange(len(vectors), dtype=torch.int64)
+        self.token_mapping = nn.Parameter(self.token_mapping, requires_grad=False)
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
         self.head = self.construct_head()
         self.w = self.construct_weights()
         self.tokenizer = tokenizer
 
     def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
-        weights = torch.zeros(len(self.vectors))
+        weights = torch.zeros(len(self.token_mapping))
         weights[self.pad_id] = -10_000
         return nn.Parameter(weights)
@@ -66,11 +71,16 @@ def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int
         """Load the model from a static model."""
         model.embedding = np.nan_to_num(model.embedding)
         embeddings_converted = torch.from_numpy(model.embedding)
+        if model.token_mapping is not None:
+            token_mapping = [i for _, i in sorted(model.token_mapping.items(), key=lambda x: x[0])]
+        else:
+            token_mapping = None
         return cls(
             vectors=embeddings_converted,
             pad_id=model.tokenizer.token_to_id("[PAD]"),
             out_dim=out_dim,
             tokenizer=model.tokenizer,
+            token_mapping=token_mapping,
             **kwargs,
         )
 
@@ -90,7 +100,8 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         w = w * zeros
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
-        embedded = self.embeddings(input_ids)
+        input_ids_embeddings = self.token_mapping[input_ids]
+        embedded = self.embeddings(input_ids_embeddings)
         # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
         # Mean pooling by dividing by the length
@@ -118,16 +129,17 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
         return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
 
     @property
-    def device(self) -> str:
+    def device(self) -> torch.device:
         """Get the device of the model."""
         return self.embeddings.weight.device
 
     def to_static_model(self) -> StaticModel:
         """Convert the model to a static model."""
         emb = self.embeddings.weight.detach().cpu().numpy()
         w = torch.sigmoid(self.w).detach().cpu().numpy()
+        token_mapping = {i: int(token_id) for i, token_id in enumerate(self.token_mapping.tolist())}
 
-        return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
+        return StaticModel(vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=True, token_mapping=token_mapping)
 
 
 class TextDataset(Dataset):
```
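The heart of the base.py change is an index indirection: raw tokenizer ids are routed through `token_mapping` before the embedding lookup, so the embedding matrix no longer needs one row per tokenizer id (useful when rows have been deduplicated or pruned), while the per-token weights from `construct_weights` are now sized to the tokenizer-side vocabulary. Below is a minimal sketch of the lookup semantics only; the sizes and mapping values are made up for illustration:

```python
import torch

# Hypothetical sizes: a tokenizer with 10 ids, of which only 4 rows
# survive in the embedding matrix (e.g. after deduplication).
embeddings = torch.nn.Embedding(4, 8)

# token_mapping[i] is the embedding row for tokenizer id i; here the
# dropped ids all point at row 0. In the commit, the real mapping is
# built from StaticModel.token_mapping (identity arange if absent).
token_mapping = torch.tensor([0, 1, 0, 2, 0, 0, 3, 0, 0, 0])

input_ids = torch.tensor([[1, 3, 6], [3, 6, 0]])  # raw tokenizer ids
remapped = token_mapping[input_ids]  # rows into `embeddings`
print(remapped)  # tensor([[1, 2, 3], [2, 3, 0]])
embedded = embeddings(remapped)  # shape (2, 3, 8)
```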

model2vec/train/classifier.py (2 additions, 1 deletion)
```diff
@@ -38,6 +38,7 @@ def __init__(
         hidden_dim: int = 512,
         out_dim: int = 2,
         pad_id: int = 0,
+        token_mapping: list[int] | None = None,
     ) -> None:
         """Initialize a standard classifier model."""
         self.n_layers = n_layers
@@ -46,7 +47,7 @@ def __init__(
         self.classes_: list[str] = [str(x) for x in range(out_dim)]
         # multilabel flag will be set based on the type of `y` passed to fit.
         self.multilabel: bool = False
-        super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)
+        super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer, token_mapping=token_mapping)
 
     @property
     def classes(self) -> np.ndarray:
```
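To see the parameter travel end to end, here is a hedged sketch of the intended round trip. It assumes the public `StaticModelForClassification` import path, uses a real model name purely as an example, and relies on the identity fallback in `__init__` if the loaded model carries no `token_mapping`:

```python
from model2vec import StaticModel
from model2vec.train import StaticModelForClassification

# Example model name; any model2vec static model should behave the same.
static = StaticModel.from_pretrained("minishlab/potion-base-8M")

# from_static_model flattens static.token_mapping (a dict of tokenizer
# id -> embedding row) into the sorted-list form the constructor takes;
# a missing mapping falls back to an identity arange.
clf = StaticModelForClassification.from_static_model(model=static)

# ... fit the classifier here ...

# to_static_model re-emits the mapping as a dict, so the exported model
# keeps the same indirection between tokenizer ids and embedding rows.
exported = clf.to_static_model()
```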
