Skip to content

Commit 2b8ad5f

Browse files
fix: add_pooling_layer for bert classification (#190)
1 parent e7ae777 commit 2b8ad5f

13 files changed: +138 −19 lines changed

backends/candle/src/models/bert.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ pub trait ClassificationHead {
365365
}
366366

367367
pub struct BertClassificationHead {
368+
pooler: Option<Linear>,
368369
output: Linear,
369370
span: tracing::Span,
370371
}
@@ -376,11 +377,24 @@ impl BertClassificationHead {
376377
Some(id2label) => id2label.len(),
377378
};
378379

379-
let output_weight = vb.get((n_classes, config.hidden_size), "weight")?;
380-
let output_bias = vb.get(n_classes, "bias")?;
380+
let pooler = if let Ok(pooler_weight) = vb
381+
.pp("bert.pooler.dense")
382+
.get((config.hidden_size, config.hidden_size), "weight")
383+
{
384+
let pooler_bias = vb.pp("bert.pooler.dense").get(config.hidden_size, "bias")?;
385+
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
386+
} else {
387+
None
388+
};
389+
390+
let output_weight = vb
391+
.pp("classifier")
392+
.get((n_classes, config.hidden_size), "weight")?;
393+
let output_bias = vb.pp("classifier").get(n_classes, "bias")?;
381394
let output = Linear::new(output_weight, Some(output_bias), None);
382395

383396
Ok(Self {
397+
pooler,
384398
output,
385399
span: tracing::span!(tracing::Level::TRACE, "classifier"),
386400
})
@@ -390,7 +404,14 @@ impl BertClassificationHead {
390404
impl ClassificationHead for BertClassificationHead {
391405
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
392406
let _enter = self.span.enter();
393-
let hidden_states = self.output.forward(hidden_states)?;
407+
408+
let mut hidden_states = hidden_states.clone();
409+
if let Some(pooler) = self.pooler.as_ref() {
410+
hidden_states = pooler.forward(&hidden_states)?;
411+
hidden_states = hidden_states.tanh()?;
412+
}
413+
414+
let hidden_states = self.output.forward(&hidden_states)?;
394415
Ok(hidden_states)
395416
}
396417
}
@@ -551,7 +572,7 @@ impl BertModel {
551572
let pool = Pool::Cls;
552573

553574
let classifier: Box<dyn ClassificationHead + Send> =
554-
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
575+
Box::new(BertClassificationHead::load(vb.clone(), config)?);
555576
(pool, Some(classifier), None)
556577
}
557578
ModelType::Embedding(pool) => {

backends/candle/src/models/flash_bert.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ impl FlashBertModel {
246246
let pool = Pool::Cls;
247247

248248
let classifier: Box<dyn ClassificationHead + Send> =
249-
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
249+
Box::new(BertClassificationHead::load(vb.clone(), config)?);
250250
(pool, Some(classifier), None)
251251
}
252252
ModelType::Embedding(pool) => {

backends/candle/tests/common.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,22 @@ pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>)
6565
(pooled_embeddings, raw_embeddings)
6666
}
6767

68-
pub fn download_artifacts(model_id: &'static str) -> Result<PathBuf> {
68+
pub fn download_artifacts(
69+
model_id: &'static str,
70+
revision: Option<&'static str>,
71+
) -> Result<PathBuf> {
6972
let builder = ApiBuilder::new().with_progress(false);
7073

7174
let api = builder.build().unwrap();
72-
let api_repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
75+
let api_repo = if let Some(revision) = revision {
76+
api.repo(Repo::with_revision(
77+
model_id.to_string(),
78+
RepoType::Model,
79+
revision.to_string(),
80+
))
81+
} else {
82+
api.repo(Repo::new(model_id.to_string(), RepoType::Model))
83+
};
7384

7485
api_repo.get("config.json")?;
7586
api_repo.get("tokenizer.json")?;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
source: backends/candle/tests/test_bert.rs
3+
assertion_line: 211
4+
expression: predictions_single
5+
---
6+
- - 2.8580017
7+
- -2.9722357
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
source: backends/candle/tests/test_flash_bert.rs
3+
expression: predictions_single
4+
---
5+
- - 2.8574219
6+
- -2.9726563

backends/candle/tests/test_bert.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
99
#[test]
1010
#[serial_test::serial]
1111
fn test_mini() -> Result<()> {
12-
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2")?;
12+
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2", None)?;
1313
let tokenizer = load_tokenizer(&model_root)?;
1414

1515
let backend = CandleBackend::new(
@@ -69,7 +69,7 @@ fn test_mini() -> Result<()> {
6969
#[test]
7070
#[serial_test::serial]
7171
fn test_mini_pooled_raw() -> Result<()> {
72-
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2")?;
72+
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2", None)?;
7373
let tokenizer = load_tokenizer(&model_root)?;
7474

7575
let backend = CandleBackend::new(
@@ -139,7 +139,7 @@ fn test_mini_pooled_raw() -> Result<()> {
139139
#[test]
140140
#[serial_test::serial]
141141
fn test_emotions() -> Result<()> {
142-
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions")?;
142+
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
143143
let tokenizer = load_tokenizer(&model_root)?;
144144

145145
let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
@@ -185,3 +185,38 @@ fn test_emotions() -> Result<()> {
185185

186186
Ok(())
187187
}
188+
189+
#[test]
190+
#[serial_test::serial]
191+
fn test_bert_classification() -> Result<()> {
192+
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
193+
let tokenizer = load_tokenizer(&model_root)?;
194+
195+
let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
196+
197+
let input_single = batch(
198+
vec![tokenizer
199+
.encode(
200+
(
201+
"PrimeTime is a timing signoff tool",
202+
"PrimeTime can perform most accurate timing analysis",
203+
),
204+
true,
205+
)
206+
.unwrap()],
207+
[0].to_vec(),
208+
vec![],
209+
);
210+
211+
let predictions: Vec<Vec<f32>> = backend
212+
.predict(input_single)?
213+
.into_iter()
214+
.map(|(_, v)| v)
215+
.collect();
216+
let predictions_single = SnapshotScores::from(predictions);
217+
218+
let matcher = relative_matcher();
219+
insta::assert_yaml_snapshot!("bert_classification_single", predictions_single, &matcher);
220+
221+
Ok(())
222+
}

backends/candle/tests/test_flash_bert.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
1515
any(feature = "flash-attn", feature = "flash-attn-v1")
1616
))]
1717
fn test_flash_mini() -> Result<()> {
18-
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2")?;
18+
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2", None)?;
1919
let tokenizer = load_tokenizer(&model_root)?;
2020

2121
let backend = CandleBackend::new(
@@ -79,7 +79,7 @@ fn test_flash_mini() -> Result<()> {
7979
any(feature = "flash-attn", feature = "flash-attn-v1")
8080
))]
8181
fn test_flash_mini_pooled_raw() -> Result<()> {
82-
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2")?;
82+
let model_root = download_artifacts("sentence-transformers/all-MiniLM-L6-v2", None)?;
8383
let tokenizer = load_tokenizer(&model_root)?;
8484

8585
let backend = CandleBackend::new(
@@ -153,7 +153,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
153153
any(feature = "flash-attn", feature = "flash-attn-v1")
154154
))]
155155
fn test_flash_emotions() -> Result<()> {
156-
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions")?;
156+
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
157157
let tokenizer = load_tokenizer(&model_root)?;
158158

159159
let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
@@ -199,3 +199,42 @@ fn test_flash_emotions() -> Result<()> {
199199

200200
Ok(())
201201
}
202+
203+
#[test]
204+
#[serial_test::serial]
205+
#[cfg(all(
206+
feature = "cuda",
207+
any(feature = "flash-attn", feature = "flash-attn-v1")
208+
))]
209+
fn test_flash_bert_classification() -> Result<()> {
210+
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
211+
let tokenizer = load_tokenizer(&model_root)?;
212+
213+
let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
214+
215+
let input_single = batch(
216+
vec![tokenizer
217+
.encode(
218+
(
219+
"PrimeTime is a timing signoff tool",
220+
"PrimeTime can perform most accurate timing analysis",
221+
),
222+
true,
223+
)
224+
.unwrap()],
225+
[0].to_vec(),
226+
vec![],
227+
);
228+
229+
let predictions: Vec<Vec<f32>> = backend
230+
.predict(input_single)?
231+
.into_iter()
232+
.map(|(_, v)| v)
233+
.collect();
234+
let predictions_single = SnapshotScores::from(predictions);
235+
236+
let matcher = relative_matcher();
237+
insta::assert_yaml_snapshot!("bert_classification_single", predictions_single, &matcher);
238+
239+
Ok(())
240+
}

backends/candle/tests/test_flash_jina.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
1111
#[serial_test::serial]
1212
#[cfg(all(feature = "cuda", feature = "flash-attn"))]
1313
fn test_flash_jina_small() -> Result<()> {
14-
let model_root = download_artifacts("jinaai/jina-embeddings-v2-small-en")?;
14+
let model_root = download_artifacts("jinaai/jina-embeddings-v2-small-en", None)?;
1515
let tokenizer = load_tokenizer(&model_root)?;
1616

1717
let backend = CandleBackend::new(

backends/candle/tests/test_flash_nomic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
1111
#[serial_test::serial]
1212
#[cfg(all(feature = "cuda", feature = "flash-attn"))]
1313
fn test_flash_nomic_small() -> Result<()> {
14-
let model_root = download_artifacts("nomic-ai/nomic-embed-text-v1.5")?;
14+
let model_root = download_artifacts("nomic-ai/nomic-embed-text-v1.5", None)?;
1515
let tokenizer = load_tokenizer(&model_root)?;
1616

1717
let backend = CandleBackend::new(

backends/candle/tests/test_jina.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use text_embeddings_backend_core::{Backend, ModelType, Pool};
88

99
#[test]
1010
fn test_jina_small() -> Result<()> {
11-
let model_root = download_artifacts("jinaai/jina-embeddings-v2-small-en")?;
11+
let model_root = download_artifacts("jinaai/jina-embeddings-v2-small-en", None)?;
1212
let tokenizer = load_tokenizer(&model_root)?;
1313

1414
let backend = CandleBackend::new(

0 commit comments

Comments (0)