@@ -3,12 +3,20 @@ use crate::layers::{
33 RopeScaling ,
44} ;
55use crate :: models:: { Model , PositionEmbeddingType } ;
6+
67use candle:: { DType , Device , IndexOp , Result , Tensor , D } ;
78use candle_nn:: { Embedding , Module , VarBuilder } ;
89use serde:: Deserialize ;
910use std:: collections:: HashMap ;
1011use text_embeddings_backend_core:: { Batch , ModelType , Pool } ;
1112
13+ #[ derive( Deserialize ) ]
14+ struct RopeParameters {
15+ pub rope_theta : f32 ,
16+ #[ allow( unused) ]
17+ rope_type : String ,
18+ }
19+
1220#[ derive( Debug , Clone , PartialEq , Deserialize ) ]
1321pub struct GTEConfig {
1422 pub vocab_size : usize ,
@@ -22,7 +30,8 @@ pub struct GTEConfig {
2230 pub layer_norm_type : String ,
2331 pub layer_norm_eps : f32 ,
2432 pub position_embedding_type : PositionEmbeddingType ,
25- pub rope_theta : f32 ,
33+ pub rope_theta : Option < f32 > ,
34+ pub rope_parameters : Option < RopeParameters > ,
2635 pub rope_scaling : Option < RopeScaling > ,
2736 #[ serde( default ) ]
2837 pub logn_attention_scale : bool ,
@@ -412,10 +421,16 @@ impl GTEModel {
412421 Self :: inner_load ( vb. pp ( "new" ) , config)
413422 . or_else ( |_| Self :: inner_load ( vb. clone ( ) , config) ) ?;
414423
424+ // NOTE: https://github.com/huggingface/transformers/pull/39847
425+ let rope_theta = config. rope_theta . unwrap_or ( match config. rope_parameters {
426+ Some ( rope_parameters) => rope_parameters. rope_theta ,
427+ None => candle:: bail!( "Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`" )
428+ } ) ;
429+
415430 let rotary_dim = encoder. layers [ 0 ] . attention . attention_head_size ;
416431 let inv_freqs = get_inv_freqs (
417432 rotary_dim,
418- config . rope_theta ,
433+ rope_theta,
419434 vb. device ( ) ,
420435 config. rope_scaling . as_ref ( ) ,
421436 ) ?;
0 commit comments