Skip to content

Commit 4910244

Browse files
committed
Add support for rope_parameters in config.json
See huggingface/transformers#39847
1 parent 2a9f924 commit 4910244

File tree

11 files changed

+107
-14
lines changed

11 files changed

+107
-14
lines changed

backends/candle/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ impl CandleBackend {
538538
rms_norm_eps: config.rms_norm_eps,
539539
model_type: config.model_type.clone(),
540540
rope_theta: config.rope_theta,
541+
rope_parameters: config.rope_parameters,
541542
sliding_window: config.sliding_window,
542543
rope_scaling: config.rope_scaling,
543544
use_bidirectional_attention: config.use_bidirectional_attention,

backends/candle/src/models/flash_gte.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,15 @@ impl FlashGTEModel {
199199
Self::inner_load(vb.pp("new"), config)
200200
.or_else(|_| Self::inner_load(vb.clone(), config))?;
201201

202+
// NOTE: https://github.com/huggingface/transformers/pull/39847
203+
let rope_theta = config.rope_theta.unwrap_or(match config.rope_parameters {
204+
Some(rope_parameters) => rope_parameters.rope_theta,
205+
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`")
206+
});
207+
202208
let inv_freqs = get_inv_freqs(
203209
layers[0].attention.attention_head_size,
204-
config.rope_theta,
210+
rope_theta,
205211
vb.device(),
206212
config.rope_scaling.as_ref(),
207213
)?;

backends/candle/src/models/flash_mistral.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,15 @@ impl FlashMistralModel {
268268

269269
let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
270270

271+
// NOTE: https://github.com/huggingface/transformers/pull/39847
272+
let rope_theta = config.rope_theta.unwrap_or(match config.rope_parameters {
273+
Some(rope_parameters) => rope_parameters.rope_theta,
274+
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`")
275+
});
276+
271277
let inv_freqs = get_inv_freqs(
272278
layers[0].attention.attention_head_size,
273-
config.rope_theta,
279+
rope_theta,
274280
vb.device(),
275281
config.rope_scaling.as_ref(),
276282
)?;

backends/candle/src/models/flash_qwen2.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,15 @@ impl FlashQwen2Model {
285285

286286
let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
287287

288+
// NOTE: https://github.com/huggingface/transformers/pull/39847
289+
let rope_theta = config.rope_theta.unwrap_or(match config.rope_parameters {
290+
Some(rope_parameters) => rope_parameters.rope_theta,
291+
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`")
292+
});
293+
288294
let inv_freqs = get_inv_freqs(
289295
layers[0].attention.attention_head_size,
290-
config.rope_theta,
296+
rope_theta,
291297
vb.device(),
292298
None,
293299
)?;

backends/candle/src/models/flash_qwen3.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,15 @@ impl FlashQwen3Model {
353353
None
354354
};
355355

356+
// NOTE: https://github.com/huggingface/transformers/pull/39847
357+
let rope_theta = config.rope_theta.unwrap_or(match config.rope_parameters {
358+
Some(rope_parameters) => rope_parameters.rope_theta,
359+
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`")
360+
});
361+
356362
let inv_freqs = get_inv_freqs(
357363
layers[0].attention.attention_head_size,
358-
config.rope_theta,
364+
rope_theta,
359365
vb.device(),
360366
None,
361367
)?;

backends/candle/src/models/gemma3.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ use candle_nn::{Embedding, Module, VarBuilder};
88
use serde::Deserialize;
99
use text_embeddings_backend_core::{Batch, ModelType, Pool};
1010

11+
#[derive(Deserialize)]
12+
struct RopeParameters {
13+
rope_theta: f32,
14+
#[allow(unused)]
15+
rope_type: String,
16+
}
17+
1118
#[derive(Debug, Clone, PartialEq, Deserialize)]
1219
pub struct Gemma3Config {
1320
pub attention_bias: bool,
@@ -23,9 +30,10 @@ pub struct Gemma3Config {
2330
pub query_pre_attn_scalar: usize,
2431
pub rms_norm_eps: f32,
2532
pub rope_local_base_freq: f32,
26-
pub rope_theta: f32,
33+
pub rope_theta: Option<f32>,
34+
pub rope_parameters: Option<RopeParameters>,
2735
pub sliding_window: Option<usize>,
28-
#[serde(rename(deserialize = "_sliding_window_pattern"))]
36+
#[serde(rename = "_sliding_window_pattern")]
2937
pub sliding_window_pattern: usize,
3038
pub vocab_size: usize,
3139
}
@@ -653,7 +661,13 @@ impl Gemma3Model {
653661
.head_dim
654662
.unwrap_or(config.hidden_size / config.num_attention_heads);
655663

656-
let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?;
664+
// NOTE: https://github.com/huggingface/transformers/pull/39847
665+
let rope_theta = config.rope_theta.unwrap_or(match config.rope_parameters {
666+
Some(rope_parameters) => rope_parameters.rope_theta,
667+
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`")
668+
});
669+
670+
let inv_freqs = get_inv_freqs(rotary_dim, rope_theta, vb.device(), None)?;
657671
let rotary_cache =
658672
get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
659673

backends/candle/src/models/gte.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@ use crate::layers::{
33
RopeScaling,
44
};
55
use crate::models::{Model, PositionEmbeddingType};
6+
67
use candle::{DType, Device, IndexOp, Result, Tensor, D};
78
use candle_nn::{Embedding, Module, VarBuilder};
89
use serde::Deserialize;
910
use std::collections::HashMap;
1011
use text_embeddings_backend_core::{Batch, ModelType, Pool};
1112

13+
#[derive(Deserialize)]
14+
struct RopeParameters {
15+
pub rope_theta: f32,
16+
#[allow(unused)]
17+
rope_type: String,
18+
}
19+
1220
#[derive(Debug, Clone, PartialEq, Deserialize)]
1321
pub struct GTEConfig {
1422
pub vocab_size: usize,
@@ -22,7 +30,8 @@ pub struct GTEConfig {
2230
pub layer_norm_type: String,
2331
pub layer_norm_eps: f32,
2432
pub position_embedding_type: PositionEmbeddingType,
25-
pub rope_theta: f32,
33+
pub rope_theta: Option<f32>,
34+
pub rope_parameters: Option<RopeParameters>,
2635
pub rope_scaling: Option<RopeScaling>,
2736
#[serde(default)]
2837
pub logn_attention_scale: bool,
@@ -412,10 +421,16 @@ impl GTEModel {
412421
Self::inner_load(vb.pp("new"), config)
413422
.or_else(|_| Self::inner_load(vb.clone(), config))?;
414423

424+
// NOTE: https://github.com/huggingface/transformers/pull/39847
425+
let rope_theta = config.rope_theta.unwrap_or(match config.rope_parameters {
426+
Some(rope_parameters) => rope_parameters.rope_theta,
427+
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` is defined in the `config.json`")
428+
});
429+
415430
let rotary_dim = encoder.layers[0].attention.attention_head_size;
416431
let inv_freqs = get_inv_freqs(
417432
rotary_dim,
418-
config.rope_theta,
433+
rope_theta,
419434
vb.device(),
420435
config.rope_scaling.as_ref(),
421436
)?;

backends/candle/src/models/llama.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
use crate::layers::{HiddenAct, RopeScaling};
22
use serde::Deserialize;
33

4+
#[derive(Deserialize)]
5+
struct RopeParameters {
6+
pub rope_theta: f32,
7+
#[allow(unused)]
8+
rope_type: String,
9+
}
10+
411
#[derive(Debug, Clone, PartialEq, Deserialize)]
512
pub struct LlamaConfig {
613
pub vocab_size: usize,
@@ -14,7 +21,8 @@ pub struct LlamaConfig {
1421
pub initializer_range: f64,
1522
pub rms_norm_eps: f32,
1623
pub model_type: Option<String>,
17-
pub rope_theta: f32,
24+
pub rope_theta: Option<f32>,
25+
pub rope_parameters: Option<RopeParameters>,
1826
pub sliding_window: Option<usize>,
1927
pub rope_scaling: Option<RopeScaling>,
2028
#[serde(default)]

backends/candle/src/models/mistral.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
use crate::layers::{HiddenAct, RopeScaling};
22
use serde::Deserialize;
33

4+
#[derive(Deserialize)]
5+
struct RopeParameters {
6+
pub rope_theta: f32,
7+
#[allow(unused)]
8+
rope_type: String,
9+
}
10+
411
#[derive(Debug, Clone, PartialEq, Deserialize)]
512
pub struct MistralConfig {
613
pub vocab_size: usize,
@@ -14,7 +21,8 @@ pub struct MistralConfig {
1421
pub initializer_range: f64,
1522
pub rms_norm_eps: f32,
1623
pub model_type: Option<String>,
17-
pub rope_theta: f32,
24+
pub rope_theta: Option<f32>,
25+
pub rope_parameters: Option<RopeParameters>,
1826
pub sliding_window: Option<usize>,
1927
pub rope_scaling: Option<RopeScaling>,
2028
#[serde(default)]

backends/candle/src/models/qwen2.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ fn default_is_causal() -> bool {
66
true
77
}
88

9+
#[derive(Deserialize)]
10+
struct RopeParameters {
11+
pub rope_theta: f32,
12+
#[allow(unused)]
13+
rope_type: String,
14+
}
15+
916
#[derive(Debug, Clone, PartialEq, Deserialize)]
1017
pub struct Qwen2Config {
1118
pub vocab_size: usize,
@@ -17,7 +24,8 @@ pub struct Qwen2Config {
1724
pub hidden_act: HiddenAct,
1825
pub max_position_embeddings: usize,
1926
pub rms_norm_eps: f32,
20-
pub rope_theta: f32,
27+
pub rope_theta: Option<f32>,
28+
pub rope_parameters: Option<RopeParameters>,
2129
pub sliding_window: Option<usize>,
2230
pub use_sliding_window: bool,
2331
#[serde(default = "default_is_causal")]

0 commit comments

Comments
 (0)