Commit 0609f5e

feat: Add configurable pooling for distillation (#288)
* Added configurable pooling for distillation
* Simplified pooling code, generalized pooling options
* Updated docstrings
1 parent e10118e commit 0609f5e

File tree

4 files changed: +200 −47 lines changed

model2vec/distill/distillation.py

Lines changed: 21 additions & 3 deletions

@@ -11,7 +11,7 @@
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
-from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
+from model2vec.distill.inference import PCADimType, PoolingType, create_embeddings, post_process_embeddings
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
@@ -33,6 +33,7 @@ def distill_from_model(
     quantize_to: DType | str = DType.Float16,
     use_subword: bool | None = None,
     vocabulary_quantization: int | None = None,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -59,7 +60,12 @@
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
-    :return: A StaticModel
+    :param pooling: The pooling strategy to use for creating embeddings. Can be one of:
+        'mean' (default): mean over all tokens. Robust and works well in most cases.
+        'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
+        'first': use the first token's hidden state ([CLS] token in BERT-style models).
+        'pooler': use the pooler output (if available). This is often a non-linear projection of the [CLS] token.
+    :return: A StaticModel.
     :raises: ValueError if the vocabulary is empty after preprocessing.
 
     """
@@ -114,7 +120,11 @@
 
     # Create the embeddings
     embeddings = create_embeddings(
-        tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
+        tokenized=token_ids,
+        model=model,
+        device=device,
+        pad_token_id=tokenizer.get_vocab()[pad_token],
+        pooling=pooling,
     )
 
     if vocabulary_quantization is not None:
@@ -142,6 +152,7 @@
         "hidden_dim": embeddings.shape[1],
         "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
         "normalize": True,
+        "pooling": pooling,
     }
 
     if os.path.exists(model_name):
@@ -226,6 +237,7 @@ def distill(
     quantize_to: DType | str = DType.Float16,
     use_subword: bool | None = None,
     vocabulary_quantization: int | None = None,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -251,6 +263,11 @@
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
+    :param pooling: The pooling strategy to use for creating embeddings. Can be one of:
+        'mean' (default): mean over all tokens. Robust and works well in most cases.
+        'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
+        'first': use the first token's hidden state ([CLS] token in BERT-style models).
+        'pooler': use the pooler output (if available). This is often a non-linear projection of the [CLS] token.
     :return: A StaticModel
 
     """
@@ -272,4 +289,5 @@
         quantize_to=quantize_to,
        use_subword=use_subword,
        vocabulary_quantization=vocabulary_quantization,
+       pooling=pooling,
    )
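
For reference, a minimal usage sketch of the new parameter (not part of this diff; the model name and save path are arbitrary examples):

from model2vec.distill import distill
from model2vec.distill.inference import PoolingType

# Distill a static model using last-token pooling instead of the default mean pooling.
m2v_model = distill(
    model_name="BAAI/bge-base-en-v1.5",  # example input model; any sentence transformer should work
    pooling=PoolingType.LAST,
)
m2v_model.save_pretrained("m2v-bge-last-pooling")

The chosen strategy is also recorded in the model config under the "pooling" key, as shown in the config hunk above.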

model2vec/distill/inference.py

Lines changed: 120 additions & 26 deletions

@@ -3,8 +3,9 @@
 
 import inspect
 import logging
+from enum import Enum
 from pathlib import Path
-from typing import Literal, Protocol, Union
+from typing import Literal, Union
 
 import numpy as np
 import torch
@@ -16,23 +17,37 @@
 
 logger = logging.getLogger(__name__)
 
-
 PathLike = Union[Path, str]
 PCADimType = Union[int, None, float, Literal["auto"]]
 
-
 _DEFAULT_BATCH_SIZE = 256
 
 
-class ModulewithWeights(Protocol):
-    weight: torch.nn.Parameter
+class PoolingType(str, Enum):
+    """
+    Pooling strategies for embedding creation.
+
+    - MEAN: masked mean over all tokens.
+    - LAST: last non-padding token (often EOS, common in decoder-style models).
+    - FIRST: first token hidden state (position 0). In BERT-style encoders,
+      this corresponds to the [CLS] token representation.
+    - POOLER: use the model's `pooler_output`. In BERT-like models this is
+      computed as the hidden state at [CLS], passed through a learned
+      dense layer + activation. Not all models provide this.
+    """
+
+    MEAN = "mean"
+    LAST = "last"
+    FIRST = "first"
+    POOLER = "pooler"
 
 
 def create_embeddings(
     model: PreTrainedModel,
     tokenized: list[list[int]],
     device: str,
     pad_token_id: int,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> np.ndarray:
     """
     Create output embeddings for a bunch of tokens using a pretrained model.
@@ -44,9 +59,11 @@ def create_embeddings(
     :param tokenized: All tokenized tokens.
     :param device: The torch device to use.
     :param pad_token_id: The pad token id. Used to pad sequences.
+    :param pooling: The pooling strategy to use.
     :return: The output embeddings.
+    :raises ValueError: If the pooling strategy is unknown.
     """
-    model = model.to(device)  # type: ignore  # Transformers error
+    model = model.to(device).eval()  # type: ignore  # Transformers error
 
     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
@@ -62,56 +79,133 @@ def create_embeddings(
     pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")
 
     for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
-        batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]
+        batch_list = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
+        batch = [torch.tensor(x, dtype=torch.long) for x in batch_list]
 
         encoded = {}
         encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
-        encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
+
+        # Create attention mask by using the lengths of each sequence
+        seq_len = encoded["input_ids"].size(1)
+        batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
+        token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
+        # Mark padding tokens with 0, and non-padding tokens with 1
+        attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
+        encoded["attention_mask"] = attention_mask.to(dtype=torch.long)
 
         if add_token_type_ids:
+            # Add token_type_ids for models that support it
            encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
 
-        out = _encode_mean_using_model(model, encoded)
+        if pooling == PoolingType.MEAN:
+            out = _encode_mean_with_model(model, encoded)
+        elif pooling == PoolingType.LAST:
+            out = _encode_last_with_model(model, encoded)
+        elif pooling == PoolingType.FIRST:
+            out = _encode_first_with_model(model, encoded)
+        elif pooling == PoolingType.POOLER:
+            out = _encode_pooler_with_model(model, encoded)
+        else:
+            raise ValueError(f"Unknown pooling: {pooling}")
+
        intermediate_weights.extend(out.numpy())
        pbar.update(len(batch))
 
     # Sort the output back to the original order
     intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
     out_weights = np.stack(intermediate_weights)
-
     out_weights = np.nan_to_num(out_weights)
 
     return out_weights
 
 
-@torch.no_grad()
-def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+def _encode_with_model(
+    model: PreTrainedModel, encodings: dict[str, torch.Tensor]
+) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
     """
-    Encode a batch of tokens using a model.
-
-    Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
-    So detection of these is necessary.
+    Move inputs to the model device, run a forward pass, and standardize dtypes.
 
     :param model: The model to use.
     :param encodings: The encoded tokens to turn into features.
-    :return: The mean of the output for each token.
+    :return: A tuple consisting of:
+        - hidden: last_hidden_state
+        - pooler: pooler_output if present, else None
+        - encodings_on_device: the device-moved encodings (for masks)
     """
-    encodings = {k: v.to(model.device) for k, v in encodings.items()}
-    encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
-    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore  # False positive
+    encodings_on_device = {k: v.to(model.device) for k, v in encodings.items()}
+    outputs: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings_on_device)
+    hidden: torch.Tensor = outputs.last_hidden_state  # type: ignore  # False positive
     # NOTE: If the dtype is bfloat16, we convert to float32,
     # because numpy does not support bfloat16.
     # See here: https://github.com/numpy/numpy/issues/19808
-    if out.dtype == torch.bfloat16:
-        out = out.float()
+    if hidden.dtype == torch.bfloat16:
+        hidden = hidden.float()
+    pooler = getattr(outputs, "pooler_output", None)
+    if pooler is not None and pooler.dtype == torch.bfloat16:
+        pooler = pooler.float()
+    return hidden, pooler, encodings_on_device
 
-    # Take the mean by averaging over the attention mask.
-    mask = encodings["attention_mask"].cpu().float()
-    mask /= mask.sum(1)[:, None]
 
-    result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)
+@torch.inference_mode()
+def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using mean pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The mean of the output for each token.
+    """
+    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
+    # Take the mean by averaging over the attention mask.
+    mask = encodings_on_device["attention_mask"].cpu().float()
+    lengths = mask.sum(1, keepdim=True).clamp_min_(1.0)
+    mask = mask / lengths
+    return torch.bmm(mask.to(hidden.device)[:, None, :], hidden).squeeze(1).cpu()
 
-    return result
 
+@torch.inference_mode()
+def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using last token pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The last hidden state for each token.
+    """
+    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
+    mask = encodings_on_device["attention_mask"].bool()
+    last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
+    batch_indices = torch.arange(hidden.size(0), device=hidden.device)
+    return hidden[batch_indices, last_idx, :].cpu()
+
+
+@torch.inference_mode()
+def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using first token (CLS) pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The first token representation for each token.
+    """
+    hidden, _, _ = _encode_with_model(model, encodings)
+    return hidden[:, 0, :].cpu()
+
+
+@torch.inference_mode()
+def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using pooler output.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The pooler output for each token.
+    :raises ValueError: If the model does not return pooler_output.
+    """
+    _, pooler, _ = _encode_with_model(model, encodings)
+    if pooler is None:
+        raise ValueError("POOLER pooling requested, but model did not return pooler_output.")
+    return pooler.cpu()
 
 
 def post_process_embeddings(
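
To make the four strategies concrete, here is a small self-contained sketch (illustrative only, not code from this commit) that applies the same length-based mask construction and pooling arithmetic as the helpers above to a toy batch:

import torch

batch_size, seq_len, hidden_dim = 2, 4, 3
hidden = torch.randn(batch_size, seq_len, hidden_dim)            # stands in for last_hidden_state
lengths = torch.tensor([4, 2])                                    # second sequence has two padding positions
mask = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)  # same length-based mask as in create_embeddings

# MEAN: masked mean over non-padding tokens.
weights = mask.float() / mask.float().sum(1, keepdim=True).clamp_min(1.0)
mean_pooled = torch.bmm(weights[:, None, :], hidden).squeeze(1)

# LAST: hidden state of the last non-padding token.
last_idx = (mask.sum(1) - 1).clamp_min(0)
last_pooled = hidden[torch.arange(batch_size), last_idx]

# FIRST: hidden state at position 0 ([CLS] in BERT-style encoders).
first_pooled = hidden[:, 0, :]

print(mean_pooled.shape, last_pooled.shape, first_pooled.shape)   # each is (2, 3)

There is no standalone arithmetic for POOLER: it simply returns the model's pooler_output, and _encode_pooler_with_model above raises a ValueError when the model does not provide one.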

tests/conftest.py

Lines changed: 16 additions & 15 deletions

@@ -59,29 +59,30 @@ def mock_transformer() -> PreTrainedModel:
     """Create a mock transformer model."""
 
     class MockPreTrainedModel:
-        def __init__(self) -> None:
+        def __init__(self, dim: int = 768, with_pooler: bool = True, pooler_value: float = 7.0) -> None:
             self.device = "cpu"
             self.name_or_path = "mock-model"
+            self.dim = dim
+            self.with_pooler = with_pooler
+            self.pooler_value = pooler_value
 
         def to(self, device: str) -> MockPreTrainedModel:
             self.device = device
             return self
 
+        def eval(self) -> MockPreTrainedModel:
+            return self
+
         def forward(self, *args: Any, **kwargs: Any) -> Any:
-            # Simulate a last_hidden_state output for a transformer model
-            batch_size, seq_length = kwargs["input_ids"].shape
-            # Return a tensor of shape (batch_size, seq_length, 768)
-            return type(
-                "BaseModelOutputWithPoolingAndCrossAttentions",
-                (object,),
-                {
-                    "last_hidden_state": torch.rand(batch_size, seq_length, 768)  # Simulate 768 hidden units
-                },
-            )
-
-        def __call__(self, *args: Any, **kwargs: Any) -> Any:
-            # Simply call the forward method to simulate the same behavior as transformers models
-            return self.forward(*args, **kwargs)
+            input_ids = kwargs["input_ids"]
+            B, T = input_ids.shape
+            hidden = torch.arange(T, dtype=torch.float32, device=self.device).repeat(B, self.dim, 1).transpose(1, 2)
+            out = {"last_hidden_state": hidden}
+            if self.with_pooler:
+                out["pooler_output"] = torch.full((B, self.dim), self.pooler_value, device=self.device)
+            return type("BaseModelOutputWithPoolingAndCrossAttentions", (object,), out)()
+
+        __call__ = forward
 
     return cast(PreTrainedModel, MockPreTrainedModel())
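
Since the mock's hidden state at position t now equals t across every hidden dimension and its pooler output is a constant, each pooling strategy has an exact, easy-to-assert result. A hypothetical illustration (MockPreTrainedModel is local to the fixture above, so this is only to spell out the arithmetic):

import torch

mock = MockPreTrainedModel(dim=4, pooler_value=7.0)            # fixture-local class, used directly for illustration
out = mock(input_ids=torch.zeros(1, 3, dtype=torch.long))      # batch of one sequence of length 3

assert torch.equal(out.last_hidden_state[0, :, 0], torch.tensor([0.0, 1.0, 2.0]))
assert torch.all(out.pooler_output == 7.0)
# For a length-3 sequence: mean pooling -> 1.0, last -> 2.0, first -> 0.0, pooler -> 7.0.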

Comments (0)