@@ -440,11 +440,95 @@ impl ClassificationHead for RobertaClassificationHead {
     }
 }
 
+#[derive(Debug)]
+pub struct BertSpladeHead {
+    transform: Linear,
+    transform_layer_norm: LayerNorm,
+    decoder: Linear,
+    span: tracing::Span,
+}
+
+impl BertSpladeHead {
+    pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
+        let vb = vb.pp("cls.predictions");
+        let transform_weight = vb
+            .pp("transform.dense")
+            .get((config.hidden_size, config.hidden_size), "weight")?;
+        let transform_bias = vb.pp("transform.dense").get(config.hidden_size, "bias")?;
+        let transform = Linear::new(
+            transform_weight,
+            Some(transform_bias),
+            Some(config.hidden_act.clone()),
+        );
+
+        let transform_layer_norm = LayerNorm::load(
+            vb.pp("transform.LayerNorm"),
+            config.hidden_size,
+            config.layer_norm_eps as f32,
+        )?;
+
+        let decoder_weight = vb
+            .pp("decoder")
+            .get((config.vocab_size, config.hidden_size), "weight")?;
+        let decoder_bias = vb.get(config.vocab_size, "bias")?;
+        let decoder = Linear::new(decoder_weight, Some(decoder_bias), Some(HiddenAct::Relu));
+
+        Ok(Self {
+            transform,
+            transform_layer_norm,
+            decoder,
+            span: tracing::span!(tracing::Level::TRACE, "splade"),
+        })
+    }
+
+    pub(crate) fn load_roberta(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
+        let vb = vb.pp("lm_head");
+        let transform_weight = vb
+            .pp("dense")
+            .get((config.hidden_size, config.hidden_size), "weight")?;
+        let transform_bias = vb.pp("dense").get(config.hidden_size, "bias")?;
+        let transform = Linear::new(
+            transform_weight,
+            Some(transform_bias),
+            Some(HiddenAct::Gelu),
+        );
+
+        let transform_layer_norm = LayerNorm::load(
+            vb.pp("layer_norm"),
+            config.hidden_size,
+            config.layer_norm_eps as f32,
+        )?;
+
+        let decoder_weight = vb
+            .pp("decoder")
+            .get((config.vocab_size, config.hidden_size), "weight")?;
+        let decoder_bias = vb.get(config.vocab_size, "bias")?;
+        let decoder = Linear::new(decoder_weight, Some(decoder_bias), Some(HiddenAct::Relu));
+
+        Ok(Self {
+            transform,
+            transform_layer_norm,
+            decoder,
+            span: tracing::span!(tracing::Level::TRACE, "splade"),
+        })
+    }
+
+    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = self.transform.forward(hidden_states)?;
+        let hidden_states = self.transform_layer_norm.forward(&hidden_states, None)?;
+        let hidden_states = self.decoder.forward(&hidden_states)?;
+        (1.0 + hidden_states)?.log()
+    }
+}
+
 pub struct BertModel {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
     pool: Pool,
     classifier: Option<Box<dyn ClassificationHead + Send>>,
+    splade: Option<BertSpladeHead>,
 
     num_attention_heads: usize,
 
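The new `BertSpladeHead` reuses the checkpoint's masked-language-model head: BERT checkpoints store it under `cls.predictions` (with `transform.dense` and `transform.LayerNorm`), while RoBERTa checkpoints use `lm_head` (with `dense` and `layer_norm`) and a fixed GELU transform, hence the two loaders. Because the ReLU is fused into `decoder` via `HiddenAct::Relu`, `forward` only has to add the logarithm, yielding the SPLADE activation log(1 + ReLU(logits)). A minimal standalone sketch of that activation in candle (the function name is illustrative, not part of this diff):

```rust
use candle::{Result, Tensor};

// Sketch only: SPLADE term weights from raw MLM logits of shape
// (batch, seq_len, vocab_size). The head above fuses the ReLU into the
// decoder Linear, so its forward computes just the log(1 + x) step.
fn splade_activation(logits: &Tensor) -> Result<Tensor> {
    // log(1 + ReLU(x)) keeps weights non-negative and encourages sparsity
    (1.0 + logits.relu()?)?.log()
}
```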
@@ -461,20 +545,22 @@ impl BertModel {
             candle::bail!("Bert only supports absolute position embeddings")
         }
 
-        let (pool, classifier) = match model_type {
+        let (pool, classifier, splade) = match model_type {
             // Classifier models always use CLS pooling
             ModelType::Classifier => {
                 let pool = Pool::Cls;
 
                 let classifier: Box<dyn ClassificationHead + Send> =
                     Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
-                (pool, Some(classifier))
+                (pool, Some(classifier), None)
             }
             ModelType::Embedding(pool) => {
-                if pool == Pool::Splade {
-                    candle::bail!("`splade` is not supported for Nomic")
-                }
-                (pool, None)
+                let splade = if pool == Pool::Splade {
+                    Some(BertSpladeHead::load(vb.clone(), config)?)
+                } else {
+                    None
+                };
+                (pool, None, splade)
             }
         };
 
@@ -500,6 +586,7 @@ impl BertModel {
             encoder,
             pool,
             classifier,
+            splade,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
             dtype: vb.dtype(),
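With the head threaded through `load`, requesting SPLADE pooling on an embedding model now loads the head instead of bailing out. A hypothetical call site, assuming the surrounding `load` signature is `(vb: VarBuilder, config: &BertConfig, model_type: ModelType)` as suggested by the bodies above:

```rust
// Hypothetical usage; the exact signature of BertModel::load is assumed,
// not shown in this diff.
let model = BertModel::load(vb, &config, ModelType::Embedding(Pool::Splade))?;
```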
@@ -517,17 +604,24 @@ impl BertModel {
             candle::bail!("Bert only supports absolute position embeddings")
         }
 
-        let (pool, classifier) = match model_type {
+        let (pool, classifier, splade) = match model_type {
             // Classifier models always use CLS pooling
             ModelType::Classifier => {
                 let pool = Pool::Cls;
 
                 let classifier: Box<dyn ClassificationHead + Send> = Box::new(
                     RobertaClassificationHead::load(vb.pp("classifier"), config)?,
                 );
-                (pool, Some(classifier))
+                (pool, Some(classifier), None)
+            }
+            ModelType::Embedding(pool) => {
+                let splade = if pool == Pool::Splade {
+                    Some(BertSpladeHead::load_roberta(vb.clone(), config)?)
+                } else {
+                    None
+                };
+                (pool, None, splade)
             }
-            ModelType::Embedding(pool) => (pool, None),
         };
 
         let (embeddings, encoder) = match (
@@ -562,6 +656,7 @@ impl BertModel {
             encoder,
             pool,
             classifier,
+            splade,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
             dtype: vb.dtype(),
@@ -730,7 +825,25 @@ impl BertModel {
 
                 (outputs.sum(1)?.broadcast_div(&input_lengths))?
             }
-            Pool::Splade => unreachable!(),
+            Pool::Splade => {
+                // Unwrap is safe here
+                let splade_head = self.splade.as_ref().unwrap();
+                let mut relu_log = splade_head.forward(&outputs)?;
+
+                if let Some(ref attention_mask) = attention_mask {
+                    let mut attention_mask = attention_mask.clone();
+
+                    if let Some(pooled_indices) = pooled_indices {
+                        // Select values in the batch
+                        attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
+                    };
+
+                    // Mask padded values
+                    relu_log = relu_log.broadcast_mul(&attention_mask)?;
+                }
+
+                relu_log.max(1)?
+            }
         };
         Some(pooled_embeddings)
     } else {
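The `Pool::Splade` arm runs the head over the raw hidden states, zeroes padded positions so they can never contribute a term weight, and then max-pools over the sequence dimension, producing one vocab-sized sparse vector per input; the `unwrap` is safe because the constructors above only return `Pool::Splade` together with a loaded head. A condensed sketch of the reduction, assuming `relu_log` of shape `(batch, seq_len, vocab_size)` and an attention mask broadcastable over the vocabulary dimension (names are illustrative):

```rust
use candle::{Result, Tensor};

// Sketch of the Pool::Splade reduction above: mask, then max over seq_len.
fn splade_pool(relu_log: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
    let masked = match attention_mask {
        // Padded positions become 0 and cannot win the max
        Some(mask) => relu_log.broadcast_mul(mask)?,
        None => relu_log.clone(),
    };
    // (batch, seq_len, vocab_size) -> (batch, vocab_size)
    masked.max(1)
}
```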