
Commit 2bce4e5

In the BERT example: apply the attention mask from tokenization during pooling (#3085)
1 parent 1febb7b commit 2bce4e5

1 file changed: 20 additions, 3 deletions

candle-examples/examples/bert/main.rs

Lines changed: 20 additions & 3 deletions
@@ -49,6 +49,10 @@ struct Args {
     /// Use tanh based approximation for Gelu instead of erf implementation.
     #[arg(long, default_value = "false")]
     approximate_gelu: bool,
+
+    /// Include padding token embeddings when performing mean pooling. By default, these are masked away.
+    #[arg(long, default_value = "false")]
+    include_padding_embeddings: bool,
 }
 
 impl Args {
@@ -177,9 +181,22 @@ fn main() -> Result<()> {
         println!("running inference on batch {:?}", token_ids.shape());
         let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         println!("generated embeddings {:?}", embeddings.shape());
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
-        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        let embeddings = if args.include_padding_embeddings {
+            // Apply avg-pooling by taking the mean embedding value for all
+            // tokens, including padding. This was the original behavior of this
+            // example, and we'd like to preserve it for posterity.
+            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+            (embeddings.sum(1)? / (n_tokens as f64))?
+        } else {
+            // Apply avg-pooling by taking the mean embedding value for all
+            // tokens (after applying the attention mask from tokenization).
+            // This should produce the same numeric result as the
+            // `sentence_transformers` Python library.
+            let attention_mask_for_pooling = attention_mask.to_dtype(DTYPE)?.unsqueeze(2)?;
+            let sum_mask = attention_mask_for_pooling.sum(1)?;
+            let embeddings = (embeddings.broadcast_mul(&attention_mask_for_pooling)?).sum(1)?;
+            embeddings.broadcast_div(&sum_mask)?
+        };
         let embeddings = if args.normalize_embeddings {
             normalize_l2(&embeddings)?
         } else {
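
Below is a minimal, self-contained sketch of the pooling math in this change, assuming candle-core is available as a dependency; the toy tensor values, shapes, and variable names are illustrative and not part of the actual example. It contrasts the legacy behavior (dividing by the padded sequence length) with the new masked pooling (dividing by the number of real tokens per sentence).

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // One sentence, three token positions, hidden size two; the last position is padding.
    let embeddings = Tensor::from_vec(vec![1f32, 2., 3., 4., 0., 0.], (1, 3, 2), &device)?;
    // Attention mask from tokenization: 1 for real tokens, 0 for padding.
    let attention_mask = Tensor::from_vec(vec![1f32, 1., 0.], (1, 3), &device)?;

    // Legacy behavior: divide by the padded length (3), so padding dilutes the mean.
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
    let unmasked = (embeddings.sum(1)? / (n_tokens as f64))?;

    // New behavior: zero out padding positions, then divide by the count of real tokens.
    let mask = attention_mask.unsqueeze(2)?; // shape (1, 3, 1), broadcastable over the hidden dim
    let sum_mask = mask.sum(1)?; // shape (1, 1): number of real tokens per sentence
    let masked = embeddings.broadcast_mul(&mask)?.sum(1)?.broadcast_div(&sum_mask)?;

    println!("unmasked mean: {unmasked}"); // values ~ [1.33, 2.0]
    println!("masked mean:   {masked}"); // values ~ [2.0, 3.0]
    Ok(())
}

In the example itself, the legacy behavior can still be reproduced by passing the new flag (clap derives `--include-padding-embeddings` from the `include_padding_embeddings` field); the masked pooling is now the default.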
