Skip to content

Commit 24533a0

Browse files
feat(backend): support classification for bert (#155)
1 parent 6395a7a commit 24533a0

File tree

2 files changed

+71
-64
lines changed

2 files changed

+71
-64
lines changed

backends/candle/src/models/bert.rs

Lines changed: 54 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -359,14 +359,49 @@ impl BertEncoder {
359359
}
360360
}
361361

362-
struct BertClassificationHead {
363-
intermediate: Linear,
362+
pub trait ClassificationHead {
363+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
364+
}
365+
366+
pub struct BertClassificationHead {
364367
output: Linear,
365368
span: tracing::Span,
366369
}
367370

368371
impl BertClassificationHead {
369-
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
372+
pub(crate) fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
373+
let n_classes = match &config.id2label {
374+
None => candle::bail!("`id2label` must be set for classifier models"),
375+
Some(id2label) => id2label.len(),
376+
};
377+
378+
let output_weight = vb.get((n_classes, config.hidden_size), "weight")?;
379+
let output_bias = vb.get(n_classes, "bias")?;
380+
let output = Linear::new(output_weight, Some(output_bias), None);
381+
382+
Ok(Self {
383+
output,
384+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
385+
})
386+
}
387+
}
388+
389+
impl ClassificationHead for BertClassificationHead {
390+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
391+
let _enter = self.span.enter();
392+
let hidden_states = self.output.forward(&hidden_states)?;
393+
Ok(hidden_states)
394+
}
395+
}
396+
397+
pub struct RobertaClassificationHead {
398+
intermediate: Linear,
399+
output: Linear,
400+
span: tracing::Span,
401+
}
402+
403+
impl RobertaClassificationHead {
404+
pub(crate) fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
370405
let n_classes = match &config.id2label {
371406
None => candle::bail!("`id2label` must be set for classifier models"),
372407
Some(id2label) => id2label.len(),
@@ -390,8 +425,10 @@ impl BertClassificationHead {
390425
span: tracing::span!(tracing::Level::TRACE, "classifier"),
391426
})
392427
}
428+
}
393429

394-
pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
430+
impl ClassificationHead for RobertaClassificationHead {
431+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
395432
let _enter = self.span.enter();
396433

397434
let hidden_states = self.intermediate.forward(hidden_states)?;
@@ -406,7 +443,7 @@ pub struct BertModel {
406443
embeddings: BertEmbeddings,
407444
encoder: BertEncoder,
408445
pool: Pool,
409-
classifier: Option<BertClassificationHead>,
446+
classifier: Option<Box<dyn ClassificationHead + Send>>,
410447

411448
num_attention_heads: usize,
412449

@@ -426,13 +463,18 @@ impl BertModel {
426463
let (pool, classifier) = match model_type {
427464
// Classifier models always use CLS pooling
428465
ModelType::Classifier => {
429-
if config.model_type == Some("bert".to_string()) {
430-
candle::bail!("`classifier` model type is not supported for Bert");
431-
}
432-
(
433-
Pool::Cls,
434-
Some(BertClassificationHead::load(vb.pp("classifier"), config)?),
435-
)
466+
let pool = Pool::Cls;
467+
468+
let classifier: Box<dyn ClassificationHead + Send> =
469+
if config.model_type == Some("bert".to_string()) {
470+
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?)
471+
} else {
472+
Box::new(RobertaClassificationHead::load(
473+
vb.pp("classifier"),
474+
config,
475+
)?)
476+
};
477+
(pool, Some(classifier))
436478
}
437479
ModelType::Embedding(pool) => (pool, None),
438480
};

backends/candle/src/models/flash_bert.rs

Lines changed: 17 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
use crate::flash_attn::flash_attn_varlen;
22
use crate::layers::{LayerNorm, Linear};
3-
use crate::models::bert::{Config, PositionEmbeddingType};
3+
use crate::models::bert::{
4+
BertClassificationHead, ClassificationHead, Config, PositionEmbeddingType,
5+
RobertaClassificationHead,
6+
};
47
use crate::models::Model;
58
use candle::{DType, Device, Result, Tensor};
69
use candle_nn::{Embedding, Module, VarBuilder};
@@ -271,54 +274,11 @@ impl BertEncoder {
271274
}
272275
}
273276

274-
struct BertClassificationHead {
275-
intermediate: Linear,
276-
output: Linear,
277-
span: tracing::Span,
278-
}
279-
280-
impl BertClassificationHead {
281-
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
282-
let n_classes = match &config.id2label {
283-
None => candle::bail!("`id2label` must be set for classifier models"),
284-
Some(id2label) => id2label.len(),
285-
};
286-
287-
let intermediate_weight = vb
288-
.pp("dense")
289-
.get((config.hidden_size, config.hidden_size), "weight")?;
290-
let intermediate_bias = vb.pp("dense").get(config.hidden_size, "bias")?;
291-
let intermediate = Linear::new(intermediate_weight, Some(intermediate_bias), None);
292-
293-
let output_weight = vb
294-
.pp("out_proj")
295-
.get((n_classes, config.hidden_size), "weight")?;
296-
let output_bias = vb.pp("out_proj").get(n_classes, "bias")?;
297-
let output = Linear::new(output_weight, Some(output_bias), None);
298-
299-
Ok(Self {
300-
intermediate,
301-
output,
302-
span: tracing::span!(tracing::Level::TRACE, "classifier"),
303-
})
304-
}
305-
306-
pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
307-
let _enter = self.span.enter();
308-
309-
let hidden_states = self.intermediate.forward(hidden_states)?;
310-
let hidden_states = hidden_states.tanh()?;
311-
let hidden_states = self.output.forward(&hidden_states)?;
312-
313-
Ok(hidden_states)
314-
}
315-
}
316-
317277
pub struct FlashBertModel {
318278
embeddings: BertEmbeddings,
319279
encoder: BertEncoder,
320280
pool: Pool,
321-
classifier: Option<BertClassificationHead>,
281+
classifier: Option<Box<dyn ClassificationHead + Send>>,
322282
pub device: Device,
323283

324284
span: tracing::Span,
@@ -343,13 +303,18 @@ impl FlashBertModel {
343303
let (pool, classifier) = match model_type {
344304
// Classifier models always use CLS pooling
345305
ModelType::Classifier => {
346-
if config.model_type == Some("bert".to_string()) {
347-
candle::bail!("`classifier` model type is not supported for Bert");
348-
}
349-
(
350-
Pool::Cls,
351-
Some(BertClassificationHead::load(vb.pp("classifier"), config)?),
352-
)
306+
let pool = Pool::Cls;
307+
308+
let classifier: Box<dyn ClassificationHead + Send> =
309+
if config.model_type == Some("bert".to_string()) {
310+
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?)
311+
} else {
312+
Box::new(RobertaClassificationHead::load(
313+
vb.pp("classifier"),
314+
config,
315+
)?)
316+
};
317+
(pool, Some(classifier))
353318
}
354319
ModelType::Embedding(pool) => (pool, None),
355320
};

0 commit comments

Comments (0)