@@ -359,14 +359,49 @@ impl BertEncoder {
359
359
}
360
360
}
361
361
362
- struct BertClassificationHead {
363
- intermediate : Linear ,
362
+ pub trait ClassificationHead {
363
+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > ;
364
+ }
365
+
366
+ pub struct BertClassificationHead {
364
367
output : Linear ,
365
368
span : tracing:: Span ,
366
369
}
367
370
368
371
impl BertClassificationHead {
369
- pub fn load ( vb : VarBuilder , config : & Config ) -> Result < Self > {
372
+ pub ( crate ) fn load ( vb : VarBuilder , config : & Config ) -> Result < Self > {
373
+ let n_classes = match & config. id2label {
374
+ None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
375
+ Some ( id2label) => id2label. len ( ) ,
376
+ } ;
377
+
378
+ let output_weight = vb. get ( ( n_classes, config. hidden_size ) , "weight" ) ?;
379
+ let output_bias = vb. get ( n_classes, "bias" ) ?;
380
+ let output = Linear :: new ( output_weight, Some ( output_bias) , None ) ;
381
+
382
+ Ok ( Self {
383
+ output,
384
+ span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
385
+ } )
386
+ }
387
+ }
388
+
389
+ impl ClassificationHead for BertClassificationHead {
390
+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
391
+ let _enter = self . span . enter ( ) ;
392
+ let hidden_states = self . output . forward ( & hidden_states) ?;
393
+ Ok ( hidden_states)
394
+ }
395
+ }
396
+
397
+ pub struct RobertaClassificationHead {
398
+ intermediate : Linear ,
399
+ output : Linear ,
400
+ span : tracing:: Span ,
401
+ }
402
+
403
+ impl RobertaClassificationHead {
404
+ pub ( crate ) fn load ( vb : VarBuilder , config : & Config ) -> Result < Self > {
370
405
let n_classes = match & config. id2label {
371
406
None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
372
407
Some ( id2label) => id2label. len ( ) ,
@@ -390,8 +425,10 @@ impl BertClassificationHead {
390
425
span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
391
426
} )
392
427
}
428
+ }
393
429
394
- pub fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
430
+ impl ClassificationHead for RobertaClassificationHead {
431
+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
395
432
let _enter = self . span . enter ( ) ;
396
433
397
434
let hidden_states = self . intermediate . forward ( hidden_states) ?;
@@ -406,7 +443,7 @@ pub struct BertModel {
406
443
embeddings : BertEmbeddings ,
407
444
encoder : BertEncoder ,
408
445
pool : Pool ,
409
- classifier : Option < BertClassificationHead > ,
446
+ classifier : Option < Box < dyn ClassificationHead + Send > > ,
410
447
411
448
num_attention_heads : usize ,
412
449
@@ -426,13 +463,18 @@ impl BertModel {
426
463
let ( pool, classifier) = match model_type {
427
464
// Classifier models always use CLS pooling
428
465
ModelType :: Classifier => {
429
- if config. model_type == Some ( "bert" . to_string ( ) ) {
430
- candle:: bail!( "`classifier` model type is not supported for Bert" ) ;
431
- }
432
- (
433
- Pool :: Cls ,
434
- Some ( BertClassificationHead :: load ( vb. pp ( "classifier" ) , config) ?) ,
435
- )
466
+ let pool = Pool :: Cls ;
467
+
468
+ let classifier: Box < dyn ClassificationHead + Send > =
469
+ if config. model_type == Some ( "bert" . to_string ( ) ) {
470
+ Box :: new ( BertClassificationHead :: load ( vb. pp ( "classifier" ) , config) ?)
471
+ } else {
472
+ Box :: new ( RobertaClassificationHead :: load (
473
+ vb. pp ( "classifier" ) ,
474
+ config,
475
+ ) ?)
476
+ } ;
477
+ ( pool, Some ( classifier) )
436
478
}
437
479
ModelType :: Embedding ( pool) => ( pool, None ) ,
438
480
} ;
0 commit comments