Commit 1e3b58e

Added configurable pooling for distillation
1 parent e10118e commit 1e3b58e

File tree: 4 files changed (+161, −59 lines)

model2vec/distill/distillation.py

Lines changed: 12 additions & 2 deletions
@@ -11,7 +11,7 @@
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
-from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
+from model2vec.distill.inference import PCADimType, PoolingType, create_embeddings, post_process_embeddings
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
@@ -33,6 +33,7 @@ def distill_from_model(
     quantize_to: DType | str = DType.Float16,
     use_subword: bool | None = None,
     vocabulary_quantization: int | None = None,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -59,6 +60,7 @@
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
+    :param pooling: The pooling strategy to use for creating embeddings. Can be one of "mean", "last", or "cls".
     :return: A StaticModel
     :raises: ValueError if the vocabulary is empty after preprocessing.
@@ -114,7 +116,11 @@
 
     # Create the embeddings
     embeddings = create_embeddings(
-        tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
+        tokenized=token_ids,
+        model=model,
+        device=device,
+        pad_token_id=tokenizer.get_vocab()[pad_token],
+        pooling=pooling,
     )
 
     if vocabulary_quantization is not None:
@@ -142,6 +148,7 @@
         "hidden_dim": embeddings.shape[1],
         "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
         "normalize": True,
+        "pooling": pooling,
     }
 
     if os.path.exists(model_name):
@@ -226,6 +233,7 @@ def distill(
     quantize_to: DType | str = DType.Float16,
     use_subword: bool | None = None,
     vocabulary_quantization: int | None = None,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -251,6 +259,7 @@
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
+    :param pooling: The pooling strategy to use for creating embeddings. Can be one of "mean", "last", or "cls".
     :return: A StaticModel
 
     """
@@ -272,4 +281,5 @@
         quantize_to=quantize_to,
        use_subword=use_subword,
        vocabulary_quantization=vocabulary_quantization,
+        pooling=pooling,
     )
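
For orientation, a minimal usage sketch of the new keyword. This assumes distill is importable from model2vec.distill and that the returned StaticModel exposes save_pretrained; the checkpoint name and the model_name= spelling are placeholders, and only the pooling= argument and the PoolingType import come from this commit.

from model2vec.distill import distill
from model2vec.distill.inference import PoolingType

# Hypothetical call site: the checkpoint and surrounding arguments are illustrative;
# pooling= is the new argument added in this commit (default: PoolingType.MEAN).
m2v_model = distill(
    model_name="BAAI/bge-base-en-v1.5",  # placeholder checkpoint
    pooling=PoolingType.LAST,            # or PoolingType.MEAN / PoolingType.CLS
)
m2v_model.save_pretrained("bge-base-distilled-last")  # assumed StaticModel API

Because PoolingType subclasses str, the plain strings "mean", "last", and "cls" compare equal to the enum members, so passing a bare string should also work.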

model2vec/distill/inference.py

Lines changed: 104 additions & 39 deletions
@@ -3,8 +3,9 @@
 
 import inspect
 import logging
+from enum import Enum
 from pathlib import Path
-from typing import Literal, Protocol, Union
+from typing import Literal, Union
 
 import numpy as np
 import torch
@@ -16,23 +17,26 @@
 
 logger = logging.getLogger(__name__)
 
-
 PathLike = Union[Path, str]
 PCADimType = Union[int, None, float, Literal["auto"]]
 
-
 _DEFAULT_BATCH_SIZE = 256
 
 
-class ModulewithWeights(Protocol):
-    weight: torch.nn.Parameter
+class PoolingType(str, Enum):
+    """Pooling strategies for embedding creation."""
+
+    MEAN = "mean"
+    LAST = "last"
+    CLS = "cls"
 
 
 def create_embeddings(
     model: PreTrainedModel,
     tokenized: list[list[int]],
     device: str,
     pad_token_id: int,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> np.ndarray:
     """
     Create output embeddings for a bunch of tokens using a pretrained model.
@@ -44,9 +48,11 @@ def create_embeddings(
     :param tokenized: All tokenized tokens.
     :param device: The torch device to use.
     :param pad_token_id: The pad token id. Used to pad sequences.
+    :param pooling: The pooling strategy to use.
     :return: The output embeddings.
+    :raises ValueError: If the pooling strategy is unknown.
     """
-    model = model.to(device)  # type: ignore  # Transformers error
+    model = model.to(device).eval()  # type: ignore  # Transformers error
 
     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
@@ -62,56 +68,123 @@
     pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")
 
     for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
-        batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]
+        batch_list = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
+        batch = [torch.tensor(x, dtype=torch.long) for x in batch_list]
 
         encoded = {}
         encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
-        encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
+
+        if pooling == PoolingType.MEAN:
+            # For mean pooling, mask out padding tokens
+            encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
+        else:
+            # For "last"/"cls": build mask directly from true lengths to ensure
+            # the last non-pad token and CLS positions are chosen correctly
+            seq_len = encoded["input_ids"].size(1)
+            batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
+            token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
+            encoded["attention_mask"] = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
 
         if add_token_type_ids:
+            # Add token_type_ids for models that support it
             encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
 
-        out = _encode_mean_using_model(model, encoded)
+        if pooling == PoolingType.MEAN:
+            out = _encode_mean_with_model(model, encoded)
+        elif pooling == PoolingType.LAST:
+            out = _encode_last_with_model(model, encoded)
+        elif pooling == PoolingType.CLS:
+            out = _encode_cls_with_model(model, encoded)
+        else:
+            raise ValueError(f"Unknown pooling: {pooling}")
+
         intermediate_weights.extend(out.numpy())
         pbar.update(len(batch))
 
     # Sort the output back to the original order
     intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
     out_weights = np.stack(intermediate_weights)
-
     out_weights = np.nan_to_num(out_weights)
 
     return out_weights
 
 
-@torch.no_grad()
-def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+def _encode_with_model(
+    model: PreTrainedModel, encodings: dict[str, torch.Tensor]
+) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
     """
-    Encode a batch of tokens using a model.
-
-    Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
-    So detection of these is necessary.
+    Move inputs to the model device, run a forward pass, and standardize dtypes.
 
     :param model: The model to use.
     :param encodings: The encoded tokens to turn into features.
-    :return: The mean of the output for each token.
+    :return: a tuple consisting of:
+        - hidden: last_hidden_state
+        - pooler: pooler_output if present, else None
+        - encodings_on_device: the device-moved encodings (for masks)
     """
-    encodings = {k: v.to(model.device) for k, v in encodings.items()}
-    encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
-    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore  # False positive
+    encodings_on_device = {k: v.to(model.device) for k, v in encodings.items()}
+    outputs: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings_on_device)
+    hidden: torch.Tensor = outputs.last_hidden_state  # type: ignore  # False positive
     # NOTE: If the dtype is bfloat 16, we convert to float32,
     # because numpy does not suport bfloat16
     # See here: https://github.com/numpy/numpy/issues/19808
-    if out.dtype == torch.bfloat16:
-        out = out.float()
+    if hidden.dtype == torch.bfloat16:
+        hidden = hidden.float()
+    pooler = getattr(outputs, "pooler_output", None)
+    if pooler is not None and pooler.dtype == torch.bfloat16:
+        pooler = pooler.float()
+    return hidden, pooler, encodings_on_device
 
+
+@torch.inference_mode()
+def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using mean pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The mean of the output for each token.
+    """
+    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
     # Take the mean by averaging over the attention mask.
-    mask = encodings["attention_mask"].cpu().float()
-    mask /= mask.sum(1)[:, None]
+    mask = encodings_on_device["attention_mask"].cpu().float()
+    lengths = mask.sum(1, keepdim=True).clamp_min_(1.0)
+    mask = mask / lengths
+    return torch.bmm(mask.to(hidden.device)[:, None, :], hidden).squeeze(1).cpu()
 
-    result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)
 
-    return result
+@torch.inference_mode()
+def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using last token pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The last hidden state for each token.
+    """
+    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
+    # Get the last hidden state for each token
+    mask = encodings_on_device["attention_mask"].bool()
+    last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
+    b = torch.arange(hidden.size(0), device=hidden.device)
+    return hidden[b, last_idx, :].cpu()
+
+
+@torch.inference_mode()
+def _encode_cls_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using CLS pooling.
+
+    If the model has a pooler_output, use that, otherwise, use the first token's hidden state.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The [CLS] token representation for each token.
+    """
+    hidden, pooler, _ = _encode_with_model(model, encodings)
+    if pooler is not None:
+        return pooler.cpu()
+    return hidden[:, 0, :].cpu()
 
 
 def post_process_embeddings(
@@ -124,30 +197,22 @@ def post_process_embeddings(
     if pca_dims > embeddings.shape[1]:
         logger.warning(
             f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
-            "Applying PCA, but not reducing dimensionality. Is this is not desired, please set `pca_dims` to None. "
-            "Applying PCA will probably improve performance, so consider just leaving it."
+            "Applying PCA, but not reducing dimensionality. If this is not desired, set `pca_dims` to None."
         )
         pca_dims = embeddings.shape[1]
     if pca_dims >= embeddings.shape[0]:
         logger.warning(
             f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
         )
     elif pca_dims <= embeddings.shape[1]:
-        if isinstance(pca_dims, float):
-            logger.info(f"Applying PCA with {pca_dims} explained variance.")
-        else:
-            logger.info(f"Applying PCA with n_components {pca_dims}")
-
         orig_dims = embeddings.shape[1]
         p = PCA(n_components=pca_dims, svd_solver="full")
         embeddings = p.fit_transform(embeddings)
-
         if embeddings.shape[1] < orig_dims:
-            explained_variance_ratio = np.sum(p.explained_variance_ratio_)
-            explained_variance = np.sum(p.explained_variance_)
-            logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
-            logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
-            logger.info(f"Explained variance: {explained_variance:.3f}.")
+            logger.info(
+                f"Reduced dimensionality {orig_dims} -> {embeddings.shape[1]} "
+                f"(explained var ratio: {np.sum(p.explained_variance_ratio_):.3f})."
+            )
 
     if sif_coefficient is not None:
         logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
tests/conftest.py

Lines changed: 16 additions & 15 deletions
@@ -59,29 +59,30 @@ def mock_transformer() -> PreTrainedModel:
     """Create a mock transformer model."""
 
     class MockPreTrainedModel:
-        def __init__(self) -> None:
+        def __init__(self, dim: int = 768, with_pooler: bool = True, pooler_value: float = 7.0) -> None:
             self.device = "cpu"
             self.name_or_path = "mock-model"
+            self.dim = dim
+            self.with_pooler = with_pooler
+            self.pooler_value = pooler_value
 
         def to(self, device: str) -> MockPreTrainedModel:
             self.device = device
             return self
 
+        def eval(self) -> MockPreTrainedModel:
+            return self
+
         def forward(self, *args: Any, **kwargs: Any) -> Any:
-            # Simulate a last_hidden_state output for a transformer model
-            batch_size, seq_length = kwargs["input_ids"].shape
-            # Return a tensor of shape (batch_size, seq_length, 768)
-            return type(
-                "BaseModelOutputWithPoolingAndCrossAttentions",
-                (object,),
-                {
-                    "last_hidden_state": torch.rand(batch_size, seq_length, 768)  # Simulate 768 hidden units
-                },
-            )
-
-        def __call__(self, *args: Any, **kwargs: Any) -> Any:
-            # Simply call the forward method to simulate the same behavior as transformers models
-            return self.forward(*args, **kwargs)
+            input_ids = kwargs["input_ids"]
+            B, T = input_ids.shape
+            hidden = torch.arange(T, dtype=torch.float32, device=self.device).repeat(B, self.dim, 1).transpose(1, 2)
+            out = {"last_hidden_state": hidden}
+            if self.with_pooler:
+                out["pooler_output"] = torch.full((B, self.dim), self.pooler_value, device=self.device)
+            return type("BaseModelOutputWithPoolingAndCrossAttentions", (object,), out)()
+
+        __call__ = forward
 
     return cast(PreTrainedModel, MockPreTrainedModel())
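
As a quick standalone check (replaying the mock's tensor construction, not part of the commit): the mock's hidden state at position t is simply t, which is what makes the expected values in the pooling test below easy to read off.

import torch

B, T, dim = 2, 3, 768
hidden = torch.arange(T, dtype=torch.float32).repeat(B, dim, 1).transpose(1, 2)
assert hidden.shape == (B, T, dim)
assert torch.equal(hidden[0, :, 0], torch.tensor([0.0, 1.0, 2.0]))  # position index per token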

tests/test_distillation.py

Lines changed: 29 additions & 3 deletions
@@ -16,6 +16,7 @@
     distill_from_model,
     post_process_embeddings,
 )
+from model2vec.distill.inference import PoolingType, create_embeddings
 from model2vec.model import StaticModel
 
 try:
@@ -251,9 +252,9 @@ def test__post_process_embeddings(
     sif_weights = (sif_coefficient / (sif_coefficient + proba))[:, None]
 
     expected_zipf_embeddings = original_embeddings * sif_weights
-    assert np.allclose(
-        processed_embeddings, expected_zipf_embeddings, rtol=1e-5
-    ), "Zipf weighting not applied correctly"
+    assert np.allclose(processed_embeddings, expected_zipf_embeddings, rtol=1e-5), (
+        "Zipf weighting not applied correctly"
+    )
 
 
 @pytest.mark.parametrize(
@@ -288,3 +289,28 @@ def test_clean_and_create_vocabulary(
     # Ensure the expected warnings contain expected keywords like 'Removed', 'duplicate', or 'empty'
     for expected_warning in expected_warnings:
         assert any(expected_warning in logged_warning for logged_warning in logged_warnings)
+
+
+@pytest.mark.parametrize(
+    "pooling,with_pooler,expected_rows",
+    [
+        (PoolingType.MEAN, False, [1.0, 0.0]),  # len=3: mean(0,1,2)=1; len=1: mean(0) = 0
+        (PoolingType.LAST, False, [2.0, 0.0]),  # last of 3: 2; last of 1: 0
+        (PoolingType.CLS, False, [0.0, 0.0]),  # first position: 0
+        (PoolingType.CLS, True, [7.0, 7.0]),  # pooler_output is used
+    ],
+)
+def test_pooling_strategies(mock_transformer, pooling, with_pooler, expected_rows) -> None:
+    """Test different pooling strategies."""
+    mock_transformer.with_pooler = with_pooler
+    tokenized = [[10, 11, 12], [20]]
+    out = create_embeddings(
+        model=mock_transformer,
+        tokenized=tokenized,
+        device="cpu",
+        pad_token_id=0,
+        pooling=pooling,
+    )
+    dim = out.shape[1]
+    expected = np.stack([np.full((dim,), v, dtype=np.float32) for v in expected_rows])
+    assert np.allclose(out, expected, rtol=1e-6, atol=0.0)
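
Assuming a standard pytest setup for this repository, the new cases can be run in isolation with "pytest tests/test_distillation.py -k test_pooling_strategies". The expected rows follow from the mock fixture above: its hidden state at position t equals t, so a length-3 input gives 1.0 under mean pooling, 2.0 under last-token pooling, and 0.0 under CLS pooling, while 7.0 appears whenever the mock's pooler_output is used.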
