 
 import inspect
 import logging
+from enum import Enum
 from pathlib import Path
-from typing import Literal, Protocol, Union
+from typing import Literal, Union
 
 import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
 
-
 PathLike = Union[Path, str]
 PCADimType = Union[int, None, float, Literal["auto"]]
 
-
 _DEFAULT_BATCH_SIZE = 256
 
 
-class ModulewithWeights(Protocol):
-    weight: torch.nn.Parameter
+class PoolingType(str, Enum):
+    """Pooling strategies for embedding creation."""
+
+    MEAN = "mean"
+    LAST = "last"
+    CLS = "cls"
 
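Since PoolingType subclasses both str and Enum, callers can pass either the member itself or its plain string value; a quick illustration outside the diff, assuming the class is imported from this module:

# Illustrative only; not part of the commit.
assert PoolingType("last") is PoolingType.LAST
assert PoolingType.MEAN == "mean"  # plain-string comparison works thanks to the str base class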
 
 def create_embeddings(
     model: PreTrainedModel,
     tokenized: list[list[int]],
     device: str,
     pad_token_id: int,
+    pooling: PoolingType = PoolingType.MEAN,
 ) -> np.ndarray:
     """
     Create output embeddings for a bunch of tokens using a pretrained model.
@@ -44,9 +48,11 @@ def create_embeddings(
     :param tokenized: All tokenized tokens.
     :param device: The torch device to use.
     :param pad_token_id: The pad token id. Used to pad sequences.
+    :param pooling: The pooling strategy to use.
     :return: The output embeddings.
+    :raises ValueError: If the pooling strategy is unknown.
     """
-    model = model.to(device)  # type: ignore  # Transformers error
+    model = model.to(device).eval()  # type: ignore  # Transformers error
 
     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
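For context, a rough usage sketch of the new pooling parameter; the checkpoint name and tokenizer handling are illustrative, not taken from this commit:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
model = AutoModel.from_pretrained("bert-base-uncased")

# Tokenize a handful of vocabulary items into plain id lists (padding happens inside).
tokenized = [tokenizer.encode(word) for word in ["dog", "cat", "embedding"]]

vectors = create_embeddings(
    model,
    tokenized,
    device="cpu",
    pad_token_id=tokenizer.pad_token_id,
    pooling=PoolingType.LAST,  # or PoolingType.MEAN / PoolingType.CLS
)
print(vectors.shape)  # (3, hidden_size)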
@@ -62,56 +68,123 @@ def create_embeddings(
     pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")
 
     for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
-        batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]
+        batch_list = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
+        batch = [torch.tensor(x, dtype=torch.long) for x in batch_list]
 
         encoded = {}
         encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
-        encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
+
+        if pooling == PoolingType.MEAN:
+            # For mean pooling, mask out padding tokens
+            encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
+        else:
+            # For "last"/"cls": build mask directly from true lengths to ensure
+            # the last non-pad token and CLS positions are chosen correctly
+            seq_len = encoded["input_ids"].size(1)
+            batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
+            token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
+            encoded["attention_mask"] = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
 
         if add_token_type_ids:
+            # Add token_type_ids for models that support it
            encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
 
-        out = _encode_mean_using_model(model, encoded)
+        if pooling == PoolingType.MEAN:
+            out = _encode_mean_with_model(model, encoded)
+        elif pooling == PoolingType.LAST:
+            out = _encode_last_with_model(model, encoded)
+        elif pooling == PoolingType.CLS:
+            out = _encode_cls_with_model(model, encoded)
+        else:
+            raise ValueError(f"Unknown pooling: {pooling}")
+
         intermediate_weights.extend(out.numpy())
         pbar.update(len(batch))
 
     # Sort the output back to the original order
     intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
     out_weights = np.stack(intermediate_weights)
-
     out_weights = np.nan_to_num(out_weights)
 
     return out_weights
 
 
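To see why the mask is built from true lengths rather than by comparing against the pad id, here is a tiny standalone sketch (values made up). Unlike an `input_ids != pad_token_id` check, it stays correct even if the pad id happens to occur inside a sequence, and the index of the last real token falls straight out of it:

import torch

input_ids = torch.tensor([[5, 6, 7], [9, 0, 0]])     # two padded sequences, pad id = 0
lengths = torch.tensor([3, 1])                        # true lengths

positions = torch.arange(input_ids.size(1))
mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
# tensor([[ True,  True,  True],
#         [ True, False, False]])

last_idx = mask.sum(dim=1) - 1                        # tensor([2, 0])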
-@torch.no_grad()
-def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+def _encode_with_model(
+    model: PreTrainedModel, encodings: dict[str, torch.Tensor]
+) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
     """
-    Encode a batch of tokens using a model.
-
-    Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
-    So detection of these is necessary.
+    Move inputs to the model device, run a forward pass, and standardize dtypes.
 
     :param model: The model to use.
     :param encodings: The encoded tokens to turn into features.
-    :return: The mean of the output for each token.
+    :return: A tuple consisting of:
+        - hidden: last_hidden_state
+        - pooler: pooler_output if present, else None
+        - encodings_on_device: the device-moved encodings (for masks)
     """
-    encodings = {k: v.to(model.device) for k, v in encodings.items()}
-    encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
-    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore  # False positive
+    encodings_on_device = {k: v.to(model.device) for k, v in encodings.items()}
+    outputs: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings_on_device)
+    hidden: torch.Tensor = outputs.last_hidden_state  # type: ignore  # False positive
     # NOTE: If the dtype is bfloat16, we convert to float32,
     # because numpy does not support bfloat16
     # See here: https://github.com/numpy/numpy/issues/19808
-    if out.dtype == torch.bfloat16:
-        out = out.float()
+    if hidden.dtype == torch.bfloat16:
+        hidden = hidden.float()
+    pooler = getattr(outputs, "pooler_output", None)
+    if pooler is not None and pooler.dtype == torch.bfloat16:
+        pooler = pooler.float()
+    return hidden, pooler, encodings_on_device
 
+
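The bfloat16-to-float32 conversion exists because NumPy has no bfloat16 dtype; a minimal illustration of the failure mode being avoided:

import torch

t = torch.ones(2, dtype=torch.bfloat16)
# t.numpy() would raise a TypeError (unsupported scalar type BFloat16)
arr = t.float().numpy()  # works: a float32 array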
+@torch.inference_mode()
+def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using mean pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The mean of the output for each token.
+    """
+    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
     # Take the mean by averaging over the attention mask.
-    mask = encodings["attention_mask"].cpu().float()
-    mask /= mask.sum(1)[:, None]
+    mask = encodings_on_device["attention_mask"].cpu().float()
+    lengths = mask.sum(1, keepdim=True).clamp_min_(1.0)
+    mask = mask / lengths
+    return torch.bmm(mask.to(hidden.device)[:, None, :], hidden).squeeze(1).cpu()
 
-    result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)
 
-    return result
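The bmm above is simply a compact way to take a masked mean; a small equivalence check, purely illustrative:

import torch

hidden = torch.randn(2, 4, 8)                                   # (batch, seq, dim)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).float()

weights = mask / mask.sum(1, keepdim=True)                      # rows sum to 1 over real tokens
via_bmm = torch.bmm(weights[:, None, :], hidden).squeeze(1)
via_sum = (hidden * mask[:, :, None]).sum(1) / mask.sum(1, keepdim=True)

assert torch.allclose(via_bmm, via_sum, atol=1e-6)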
+@torch.inference_mode()
+def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using last token pooling.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The hidden state at the last non-padding position for each token.
+    """
+    hidden, _, encodings_on_device = _encode_with_model(model, encodings)
+    # Get the last hidden state for each token
+    mask = encodings_on_device["attention_mask"].bool()
+    last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
+    b = torch.arange(hidden.size(0), device=hidden.device)
+    return hidden[b, last_idx, :].cpu()
+
+
+@torch.inference_mode()
+def _encode_cls_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using CLS pooling.
+
+    If the model has a pooler_output, use that; otherwise, use the first token's hidden state.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The [CLS] token representation for each token.
+    """
+    hidden, pooler, _ = _encode_with_model(model, encodings)
+    if pooler is not None:
+        return pooler.cpu()
+    return hidden[:, 0, :].cpu()
 
 
 def post_process_embeddings(
@@ -124,30 +197,22 @@ def post_process_embeddings(
     if pca_dims > embeddings.shape[1]:
         logger.warning(
             f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
-            "Applying PCA, but not reducing dimensionality. Is this is not desired, please set `pca_dims` to None. "
-            "Applying PCA will probably improve performance, so consider just leaving it."
+            "Applying PCA, but not reducing dimensionality. If this is not desired, set `pca_dims` to None."
         )
         pca_dims = embeddings.shape[1]
     if pca_dims >= embeddings.shape[0]:
         logger.warning(
             f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
         )
     elif pca_dims <= embeddings.shape[1]:
-        if isinstance(pca_dims, float):
-            logger.info(f"Applying PCA with {pca_dims} explained variance.")
-        else:
-            logger.info(f"Applying PCA with n_components {pca_dims}")
-
         orig_dims = embeddings.shape[1]
         p = PCA(n_components=pca_dims, svd_solver="full")
         embeddings = p.fit_transform(embeddings)
-
         if embeddings.shape[1] < orig_dims:
-            explained_variance_ratio = np.sum(p.explained_variance_ratio_)
-            explained_variance = np.sum(p.explained_variance_)
-            logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
-            logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
-            logger.info(f"Explained variance: {explained_variance:.3f}.")
+            logger.info(
+                f"Reduced dimensionality {orig_dims} -> {embeddings.shape[1]} "
+                f"(explained var ratio: {np.sum(p.explained_variance_ratio_):.3f})."
+            )
 
     if sif_coefficient is not None:
         logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
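For readers unfamiliar with the SIF step logged here, this sketches the general idea under the assumption that tokens are roughly ordered by frequency; the exact frequency estimate and weighting used in this codebase may differ:

import numpy as np

sif_coefficient = 1e-4
embeddings = np.random.randn(30_000, 256)            # stand-in for the token embeddings
vocab_size = embeddings.shape[0]

ranks = np.arange(1, vocab_size + 1)
zipf_proba = (1 / ranks) / np.sum(1 / ranks)         # Zipf's law: probability ~ 1 / rank
sif_weights = sif_coefficient / (sif_coefficient + zipf_proba)

embeddings = embeddings * sif_weights[:, None]       # down-weight frequent tokens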