Skip to content

Commit 68491f1

Browse files
authored
chore: Rename PoolingType to PoolingMode (#290)
1 parent 4f6b716 commit 68491f1

File tree

3 files changed

+20
-20
lines changed

3 files changed

+20
-20
lines changed

model2vec/distill/distillation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from transformers.modeling_utils import PreTrainedModel
1212
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
1313

14-
from model2vec.distill.inference import PCADimType, PoolingType, create_embeddings, post_process_embeddings
14+
from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings
1515
from model2vec.distill.utils import select_optimal_device
1616
from model2vec.model import StaticModel
1717
from model2vec.quantization import DType, quantize_embeddings
@@ -31,7 +31,7 @@ def distill_from_model(
3131
token_remove_pattern: str | None = r"\[unused\d+\]",
3232
quantize_to: DType | str = DType.Float16,
3333
vocabulary_quantization: int | None = None,
34-
pooling: PoolingType = PoolingType.MEAN,
34+
pooling: PoolingMode = PoolingMode.MEAN,
3535
) -> StaticModel:
3636
"""
3737
Distill a staticmodel from a sentence transformer.
@@ -55,7 +55,7 @@ def distill_from_model(
5555
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
5656
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
5757
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
58-
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
58+
:param pooling: The pooling mode to use for creating embeddings. Can be one of:
5959
'mean' (default): mean over all tokens. Robust and works well in most cases.
6060
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
6161
'first': use the first token's hidden state ([CLS] token in BERT-style models).
@@ -209,7 +209,7 @@ def distill(
209209
trust_remote_code: bool = False,
210210
quantize_to: DType | str = DType.Float16,
211211
vocabulary_quantization: int | None = None,
212-
pooling: PoolingType = PoolingType.MEAN,
212+
pooling: PoolingMode = PoolingMode.MEAN,
213213
) -> StaticModel:
214214
"""
215215
Distill a staticmodel from a sentence transformer.
@@ -232,7 +232,7 @@ def distill(
232232
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
233233
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
234234
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
235-
:param pooling: The pooling strategy to use for creating embeddings. Can be one of:
235+
:param pooling: The pooling mode to use for creating embeddings. Can be one of:
236236
'mean' (default): mean over all tokens. Robust and works well in most cases.
237237
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
238238
'first': use the first token's hidden state ([CLS] token in BERT-style models).

model2vec/distill/inference.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
_DEFAULT_BATCH_SIZE = 256
2424

2525

26-
class PoolingType(str, Enum):
26+
class PoolingMode(str, Enum):
2727
"""
28-
Pooling strategies for embedding creation.
28+
Pooling modes for embedding creation.
2929
3030
- MEAN: masked mean over all tokens.
3131
- LAST: last non-padding token (often EOS, common in decoder-style models).
@@ -47,7 +47,7 @@ def create_embeddings(
4747
tokenized: list[list[int]],
4848
device: str,
4949
pad_token_id: int,
50-
pooling: PoolingType = PoolingType.MEAN,
50+
pooling: PoolingMode = PoolingMode.MEAN,
5151
) -> np.ndarray:
5252
"""
5353
Create output embeddings for a bunch of tokens using a pretrained model.
@@ -59,9 +59,9 @@ def create_embeddings(
5959
:param tokenized: All tokenized tokens.
6060
:param device: The torch device to use.
6161
:param pad_token_id: The pad token id. Used to pad sequences.
62-
:param pooling: The pooling strategy to use.
62+
:param pooling: The pooling mode to use.
6363
:return: The output embeddings.
64-
:raises ValueError: If the pooling strategy is unknown.
64+
:raises ValueError: If the pooling mode is unknown.
6565
"""
6666
model = model.to(device).eval() # type: ignore # Transformers error
6767

@@ -97,13 +97,13 @@ def create_embeddings(
9797
# Add token_type_ids for models that support it
9898
encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
9999

100-
if pooling == PoolingType.MEAN:
100+
if pooling == PoolingMode.MEAN:
101101
out = _encode_mean_with_model(model, encoded)
102-
elif pooling == PoolingType.LAST:
102+
elif pooling == PoolingMode.LAST:
103103
out = _encode_last_with_model(model, encoded)
104-
elif pooling == PoolingType.FIRST:
104+
elif pooling == PoolingMode.FIRST:
105105
out = _encode_first_with_model(model, encoded)
106-
elif pooling == PoolingType.POOLER:
106+
elif pooling == PoolingMode.POOLER:
107107
out = _encode_pooler_with_model(model, encoded)
108108
else:
109109
raise ValueError(f"Unknown pooling: {pooling}")

tests/test_distillation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from transformers.modeling_utils import PreTrainedModel
1313

1414
from model2vec.distill.distillation import distill, distill_from_model
15-
from model2vec.distill.inference import PoolingType, create_embeddings, post_process_embeddings
15+
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
1616
from model2vec.model import StaticModel
1717
from model2vec.tokenizer import clean_and_create_vocabulary
1818

@@ -260,10 +260,10 @@ def test_clean_and_create_vocabulary(
260260
@pytest.mark.parametrize(
261261
"pooling,with_pooler,expected_rows",
262262
[
263-
(PoolingType.MEAN, False, [1.0, 0.0]), # len=3: mean(0,1,2)=1; len=1: mean(0)=0
264-
(PoolingType.LAST, False, [2.0, 0.0]), # last of 3: 2; last of 1: 0
265-
(PoolingType.FIRST, False, [0.0, 0.0]), # first position: 0
266-
(PoolingType.POOLER, True, [7.0, 7.0]), # pooler_output used
263+
(PoolingMode.MEAN, False, [1.0, 0.0]), # len=3: mean(0,1,2)=1; len=1: mean(0)=0
264+
(PoolingMode.LAST, False, [2.0, 0.0]), # last of 3: 2; last of 1: 0
265+
(PoolingMode.FIRST, False, [0.0, 0.0]), # first position: 0
266+
(PoolingMode.POOLER, True, [7.0, 7.0]), # pooler_output used
267267
],
268268
)
269269
def test_pooling_strategies(mock_transformer, pooling, with_pooler, expected_rows) -> None:
@@ -292,5 +292,5 @@ def test_pooler_raises_without_pooler_output(mock_transformer) -> None:
292292
tokenized=tokenized,
293293
device="cpu",
294294
pad_token_id=0,
295-
pooling=PoolingType.POOLER,
295+
pooling=PoolingMode.POOLER,
296296
)

0 commit comments

Comments (0)