import inspect
import logging
from enum import Enum
from pathlib import Path
from typing import Literal, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

logger = logging.getLogger(__name__)

PathLike = Union[Path, str]
PCADimType = Union[int, None, float, Literal["auto"]]

_DEFAULT_BATCH_SIZE = 256


class PoolingType(str, Enum):
    """
    Pooling strategies for embedding creation.

    - MEAN: masked mean over all tokens.
    - LAST: last non-padding token (often EOS, common in decoder-style models).
    - FIRST: first token hidden state (position 0). In BERT-style encoders,
      this corresponds to the [CLS] token representation.
    - POOLER: use the model's `pooler_output`. In BERT-like models this is
      computed as the hidden state at [CLS], passed through a learned
      dense layer + activation. Not all models provide this.
    """

    MEAN = "mean"
    LAST = "last"
    FIRST = "first"
    POOLER = "pooler"
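
# Illustrative sketch (not part of the module): what each pooling strategy computes,
# assuming `hidden` has shape (batch, seq_len, dim) and `mask` has shape (batch, seq_len)
# with 1 for real tokens and 0 for padding:
#
#   MEAN:   (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
#   LAST:   hidden[torch.arange(len(hidden)), mask.sum(1) - 1]
#   FIRST:  hidden[:, 0]
#   POOLER: model(...).pooler_output
#
# Because PoolingType subclasses str, PoolingType("mean") and comparisons against the
# plain string "mean" both work, which is convenient for config or CLI parsing.
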
def create_embeddings(
    model: PreTrainedModel,
    tokenized: list[list[int]],
    device: str,
    pad_token_id: int,
    pooling: PoolingType = PoolingType.MEAN,
) -> np.ndarray:
    """
    Create output embeddings for a bunch of tokens using a pretrained model.

    :param model: The model to use.
    :param tokenized: All tokenized tokens.
    :param device: The torch device to use.
    :param pad_token_id: The pad token id. Used to pad sequences.
    :param pooling: The pooling strategy to use.
    :return: The output embeddings.
    :raises ValueError: If the pooling strategy is unknown.
    """
    model = model.to(device).eval()  # type: ignore  # Transformers error

    out_weights: np.ndarray
    intermediate_weights: list[np.ndarray] = []
    # ... (setup of `sorted_tokenized`, `sort_order`, and `add_token_type_ids` omitted here) ...
    pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")

    for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
        batch_list = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
        batch = [torch.tensor(x, dtype=torch.long) for x in batch_list]

        encoded = {}
        encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)

        # Create attention mask by using the lengths of each sequence
        seq_len = encoded["input_ids"].size(1)
        batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
        token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
        # Mark padding tokens with 0, and non-padding tokens with 1
        attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
        encoded["attention_mask"] = attention_mask.to(dtype=torch.long)
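        # Note: building the mask from lengths (rather than `input_ids != pad_token_id`)
        # keeps positions whose token id happens to equal the pad id, e.g. when the pad
        # token itself is among the tokens being embedded.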

        if add_token_type_ids:
            # Add token_type_ids for models that support it
            encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])

        if pooling == PoolingType.MEAN:
            out = _encode_mean_with_model(model, encoded)
        elif pooling == PoolingType.LAST:
            out = _encode_last_with_model(model, encoded)
        elif pooling == PoolingType.FIRST:
            out = _encode_first_with_model(model, encoded)
        elif pooling == PoolingType.POOLER:
            out = _encode_pooler_with_model(model, encoded)
        else:
            raise ValueError(f"Unknown pooling: {pooling}")

        intermediate_weights.extend(out.numpy())
        pbar.update(len(batch))

    # Sort the output back to the original order
    intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
    out_weights = np.stack(intermediate_weights)
    out_weights = np.nan_to_num(out_weights)

    return out_weights
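
# Example (illustrative sketch, not part of this module): embedding a few pre-tokenized
# inputs with mean pooling. `AutoModel`/`AutoTokenizer` are the standard transformers
# entry points; the model name is only a placeholder.
#
#   from transformers import AutoModel, AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   mdl = AutoModel.from_pretrained("bert-base-uncased")
#   words = ["hello", "world"]
#   token_ids = [tok.encode(w) for w in words]
#   vectors = create_embeddings(mdl, token_ids, "cpu", tok.pad_token_id, PoolingType.MEAN)
#   # vectors.shape == (len(words), mdl.config.hidden_size)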


def _encode_with_model(
    model: PreTrainedModel, encodings: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
    """
    Move inputs to the model device, run a forward pass, and standardize dtypes.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: A tuple consisting of:
        - hidden: the last_hidden_state
        - pooler: the pooler_output if present, else None
        - encodings_on_device: the device-moved encodings (for masks)
    """
    encodings_on_device = {k: v.to(model.device) for k, v in encodings.items()}
    outputs: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings_on_device)
    hidden: torch.Tensor = outputs.last_hidden_state  # type: ignore  # False positive
    # NOTE: If the dtype is bfloat16, we convert to float32,
    # because numpy does not support bfloat16.
    # See here: https://github.com/numpy/numpy/issues/19808
    if hidden.dtype == torch.bfloat16:
        hidden = hidden.float()
    pooler = getattr(outputs, "pooler_output", None)
    if pooler is not None and pooler.dtype == torch.bfloat16:
        pooler = pooler.float()
    return hidden, pooler, encodings_on_device


@torch.inference_mode()
def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using mean pooling.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: The mean of the output for each token.
    """
    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
    # Take the mean by averaging over the attention mask.
    # Match the hidden dtype so the bmm below works for float16 models as well.
    mask = encodings_on_device["attention_mask"].to(dtype=hidden.dtype)
    lengths = mask.sum(1, keepdim=True).clamp_min_(1.0)
    mask = mask / lengths
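    # The batched matmul below is a per-row weighted sum: each output vector is the mean
    # of the non-padding hidden states. The clamp above only guards against division by
    # zero for rows whose mask is entirely empty.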
    return torch.bmm(mask[:, None, :], hidden).squeeze(1).cpu()


@torch.inference_mode()
def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using last token pooling.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: The hidden state at the last non-padding position for each input.
    """
    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
    mask = encodings_on_device["attention_mask"].bool()
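    # Index of the last non-padding token per row; clamp_min(0) guards against rows whose
    # attention mask is entirely zero.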
    last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
    batch_indices = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_indices, last_idx, :].cpu()


@torch.inference_mode()
def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using first token (CLS) pooling.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: The first token representation for each token.
    """
    hidden, _, _ = _encode_with_model(model, encodings)
    return hidden[:, 0, :].cpu()


@torch.inference_mode()
def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using pooler output.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: The pooler output for each token.
    :raises ValueError: If the model does not return pooler_output.
    """
    _, pooler, _ = _encode_with_model(model, encodings)
    if pooler is None:
        raise ValueError("POOLER pooling requested, but model did not return pooler_output.")
    return pooler.cpu()


def post_process_embeddings(