@@ -49,6 +49,10 @@ struct Args {
     /// Use tanh based approximation for Gelu instead of erf implementation.
     #[arg(long, default_value = "false")]
     approximate_gelu: bool,
+
+    /// Include padding token embeddings when performing mean pooling. By default, these are masked away.
+    #[arg(long, default_value = "false")]
+    include_padding_embeddings: bool,
 }
 
 impl Args {
@@ -177,9 +181,22 @@ fn main() -> Result<()> {
     println!("running inference on batch {:?}", token_ids.shape());
     let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
     println!("generated embeddings {:?}", embeddings.shape());
-    // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
-    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+    let embeddings = if args.include_padding_embeddings {
+        // Apply avg-pooling by taking the mean embedding value for all
+        // tokens, including padding. This was the original behavior of this
+        // example, and we'd like to preserve it for posterity.
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        (embeddings.sum(1)? / (n_tokens as f64))?
+    } else {
+        // Apply avg-pooling by taking the mean embedding value for all
+        // tokens (after applying the attention mask from tokenization).
+        // This should produce the same numeric result as the
+        // `sentence_transformers` Python library.
+        let attention_mask_for_pooling = attention_mask.to_dtype(DTYPE)?.unsqueeze(2)?;
+        let sum_mask = attention_mask_for_pooling.sum(1)?;
+        let embeddings = (embeddings.broadcast_mul(&attention_mask_for_pooling)?).sum(1)?;
+        embeddings.broadcast_div(&sum_mask)?
+    };
     let embeddings = if args.normalize_embeddings {
         normalize_l2(&embeddings)?
     } else {