Commit 9ab2f2c

feat: add splade pooling to Bert (#187)
1 parent b7be6c0 commit 9ab2f2c

8 files changed: 174 additions & 75 deletions

.github/workflows/build_75.yaml

Lines changed: 0 additions & 12 deletions

@@ -7,18 +7,6 @@
       - 'main'
     tags:
       - 'v*'
-  pull_request:
-    paths:
-      - ".github/workflows/build_75.yaml"
-      # - "integration-tests/**"
-      - "backends/**"
-      - "core/**"
-      - "router/**"
-      - "Cargo.lock"
-      - "rust-toolchain.toml"
-      - "Dockerfile"
-    branches:
-      - 'main'
 
 jobs:
   build-and-push-image:

.github/workflows/build_86.yaml

Lines changed: 0 additions & 12 deletions

@@ -7,18 +7,6 @@
       - 'main'
     tags:
       - 'v*'
-  pull_request:
-    paths:
-      - ".github/workflows/build.yaml"
-      # - "integration-tests/**"
-      - "backends/**"
-      - "core/**"
-      - "router/**"
-      - "Cargo.lock"
-      - "rust-toolchain.toml"
-      - "Dockerfile"
-    branches:
-      - 'main'
 
 jobs:
   build-and-push-image:

.github/workflows/build_89.yaml

Lines changed: 0 additions & 12 deletions

@@ -7,18 +7,6 @@
       - 'main'
     tags:
      - 'v*'
-  pull_request:
-    paths:
-      - ".github/workflows/build.yaml"
-      # - "integration-tests/**"
-      - "backends/**"
-      - "core/**"
-      - "router/**"
-      - "Cargo.lock"
-      - "rust-toolchain.toml"
-      - "Dockerfile"
-    branches:
-      - 'main'
 
 jobs:
   build-and-push-image:

.github/workflows/build_90.yaml

Lines changed: 0 additions & 12 deletions

@@ -7,18 +7,6 @@
       - 'main'
     tags:
       - 'v*'
-  pull_request:
-    paths:
-      - ".github/workflows/build.yaml"
-      # - "integration-tests/**"
-      - "backends/**"
-      - "core/**"
-      - "router/**"
-      - "Cargo.lock"
-      - "rust-toolchain.toml"
-      - "Dockerfile"
-    branches:
-      - 'main'
 
 jobs:
   build-and-push-image:

backends/candle/src/lib.rs

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ impl CandleBackend {
                 .to_lowercase()
                 == "true"
             {
-                tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
+                tracing::info!("Starting FlashDistilBertModel model on {:?}", device);
                 Ok(Box::new(
                     FlashDistilBertModel::load(vb, &config, model_type).s()?,
                 ))

backends/candle/src/models/bert.rs

Lines changed: 123 additions & 10 deletions

@@ -440,11 +440,95 @@ impl ClassificationHead for RobertaClassificationHead {
     }
 }
 
+#[derive(Debug)]
+pub struct BertSpladeHead {
+    transform: Linear,
+    transform_layer_norm: LayerNorm,
+    decoder: Linear,
+    span: tracing::Span,
+}
+
+impl BertSpladeHead {
+    pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
+        let vb = vb.pp("cls.predictions");
+        let transform_weight = vb
+            .pp("transform.dense")
+            .get((config.hidden_size, config.hidden_size), "weight")?;
+        let transform_bias = vb.pp("transform.dense").get(config.hidden_size, "bias")?;
+        let transform = Linear::new(
+            transform_weight,
+            Some(transform_bias),
+            Some(config.hidden_act.clone()),
+        );
+
+        let transform_layer_norm = LayerNorm::load(
+            vb.pp("transform.LayerNorm"),
+            config.hidden_size,
+            config.layer_norm_eps as f32,
+        )?;
+
+        let decoder_weight = vb
+            .pp("decoder")
+            .get((config.vocab_size, config.hidden_size), "weight")?;
+        let decoder_bias = vb.get(config.vocab_size, "bias")?;
+        let decoder = Linear::new(decoder_weight, Some(decoder_bias), Some(HiddenAct::Relu));
+
+        Ok(Self {
+            transform,
+            transform_layer_norm,
+            decoder,
+            span: tracing::span!(tracing::Level::TRACE, "splade"),
+        })
+    }
+
+    pub(crate) fn load_roberta(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
+        let vb = vb.pp("lm_head");
+        let transform_weight = vb
+            .pp("dense")
+            .get((config.hidden_size, config.hidden_size), "weight")?;
+        let transform_bias = vb.pp("dense").get(config.hidden_size, "bias")?;
+        let transform = Linear::new(
+            transform_weight,
+            Some(transform_bias),
+            Some(HiddenAct::Gelu),
+        );
+
+        let transform_layer_norm = LayerNorm::load(
+            vb.pp("layer_norm"),
+            config.hidden_size,
+            config.layer_norm_eps as f32,
+        )?;
+
+        let decoder_weight = vb
+            .pp("decoder")
+            .get((config.vocab_size, config.hidden_size), "weight")?;
+        let decoder_bias = vb.get(config.vocab_size, "bias")?;
+        let decoder = Linear::new(decoder_weight, Some(decoder_bias), Some(HiddenAct::Relu));
+
+        Ok(Self {
+            transform,
+            transform_layer_norm,
+            decoder,
+            span: tracing::span!(tracing::Level::TRACE, "splade"),
+        })
+    }
+
+    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = self.transform.forward(hidden_states)?;
+        let hidden_states = self.transform_layer_norm.forward(&hidden_states, None)?;
+        let hidden_states = self.decoder.forward(&hidden_states)?;
+        (1.0 + hidden_states)?.log()
+    }
+}
+
 pub struct BertModel {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
     pool: Pool,
     classifier: Option<Box<dyn ClassificationHead + Send>>,
+    splade: Option<BertSpladeHead>,
 
     num_attention_heads: usize,
 
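The new `BertSpladeHead` reuses the checkpoint's masked-LM prediction head (`cls.predictions` for BERT, `lm_head` for RoBERTa) and turns its vocabulary logits into sparse term weights. Since the ReLU is folded into the decoder's activation (`HiddenAct::Relu`), `forward` effectively computes log(1 + ReLU(logits)). Below is a minimal sketch of that activation in isolation, using the same `candle_core` tensor ops the diff relies on; the free function `splade_activation` is illustrative, not part of the commit:

use candle_core::{Result, Tensor};

// SPLADE activation as applied by `BertSpladeHead::forward`:
// a log-saturated ReLU over the masked-LM vocabulary logits.
fn splade_activation(mlm_logits: &Tensor) -> Result<Tensor> {
    // (batch, seq_len, vocab_size) -> same shape; every value is >= 0 and
    // exactly 0 wherever the logit was <= 0, which is what makes the
    // resulting vocabulary vector sparse.
    (1.0 + mlm_logits.relu()?)?.log()
}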
@@ -461,20 +545,22 @@ impl BertModel {
             candle::bail!("Bert only supports absolute position embeddings")
         }
 
-        let (pool, classifier) = match model_type {
+        let (pool, classifier, splade) = match model_type {
             // Classifier models always use CLS pooling
             ModelType::Classifier => {
                 let pool = Pool::Cls;
 
                 let classifier: Box<dyn ClassificationHead + Send> =
                     Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
-                (pool, Some(classifier))
+                (pool, Some(classifier), None)
             }
             ModelType::Embedding(pool) => {
-                if pool == Pool::Splade {
-                    candle::bail!("`splade` is not supported for Nomic")
-                }
-                (pool, None)
+                let splade = if pool == Pool::Splade {
+                    Some(BertSpladeHead::load(vb.clone(), config)?)
+                } else {
+                    None
+                };
+                (pool, None, splade)
             }
         };
 
@@ -500,6 +586,7 @@
             encoder,
             pool,
             classifier,
+            splade,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
             dtype: vb.dtype(),
@@ -517,17 +604,24 @@
             candle::bail!("Bert only supports absolute position embeddings")
         }
 
-        let (pool, classifier) = match model_type {
+        let (pool, classifier, splade) = match model_type {
             // Classifier models always use CLS pooling
             ModelType::Classifier => {
                 let pool = Pool::Cls;
 
                 let classifier: Box<dyn ClassificationHead + Send> = Box::new(
                     RobertaClassificationHead::load(vb.pp("classifier"), config)?,
                 );
-                (pool, Some(classifier))
+                (pool, Some(classifier), None)
+            }
+            ModelType::Embedding(pool) => {
+                let splade = if pool == Pool::Splade {
+                    Some(BertSpladeHead::load_roberta(vb.clone(), config)?)
+                } else {
+                    None
+                };
+                (pool, None, splade)
             }
-            ModelType::Embedding(pool) => (pool, None),
         };
 
         let (embeddings, encoder) = match (
@@ -562,6 +656,7 @@
             encoder,
             pool,
             classifier,
+            splade,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
             dtype: vb.dtype(),
@@ -730,7 +825,25 @@
 
                     (outputs.sum(1)?.broadcast_div(&input_lengths))?
                 }
-                Pool::Splade => unreachable!(),
+                Pool::Splade => {
+                    // Unwrap is safe here
+                    let splade_head = self.splade.as_ref().unwrap();
+                    let mut relu_log = splade_head.forward(&outputs)?;
+
+                    if let Some(ref attention_mask) = attention_mask {
+                        let mut attention_mask = attention_mask.clone();
+
+                        if let Some(pooled_indices) = pooled_indices {
+                            // Select values in the batch
+                            attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
+                        };
+
+                        // Mask padded values
+                        relu_log = relu_log.broadcast_mul(&attention_mask)?;
+                    }
+
+                    relu_log.max(1)?
+                }
             };
             Some(pooled_embeddings)
         } else {
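With the head in place, the `Pool::Splade` arm masks padded positions and takes a per-term max over the sequence dimension, yielding one vocab-sized sparse vector per pooled input. Zeroing the padded positions is safe here because log(1 + ReLU(x)) is non-negative, so a masked zero can never win the max against a real token. A hedged sketch of the arm's core follows; `splade_pool` is an illustrative helper (not an API from this commit), and the mask shape is an assumption:

use candle_core::{Result, Tensor};

// Max-pool SPLADE term weights over the sequence dimension.
// Assumed shapes: `relu_log` is (batch, seq_len, vocab_size) from the
// splade head; `attention_mask` is (batch, seq_len, 1) with 1.0 on real
// tokens and 0.0 on padding, already index-selected down to the pooled
// batch entries.
fn splade_pool(relu_log: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // Zero out padded positions so they cannot contribute to the max.
    let masked = relu_log.broadcast_mul(attention_mask)?;
    // Reduce over dim 1 (seq_len): one sparse vector per batch entry.
    masked.max(1) // -> (batch, vocab_size)
}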
