Skip to content

Commit 6c4016c

Browse files
committed
address comments
1 parent 577220b commit 6c4016c

File tree

2 files changed

+58
-68
lines changed

2 files changed

+58
-68
lines changed

mlx-lm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ description = "Rust implementation of mlx-lm"
1313
[dependencies]
1414
# Local dependencies
1515
mlx-rs.workspace = true
16+
mlx-macros.workspace = true
1617
mlx-lm-utils.workspace = true
1718

1819
# External dependencies

mlx-lm/src/utils/rope.rs

Lines changed: 57 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
use std::collections::HashMap;
22

3-
use mlx_rs::{builder::Builder, error::Exception, module::Module, nn, Array};
3+
use mlx_macros::ModuleParameters;
4+
use mlx_rs::{
5+
builder::Builder,
6+
error::Exception,
7+
module::Module,
8+
nn,
9+
ops::{arange, which},
10+
Array,
11+
};
412
use serde::Deserialize;
513

614
#[derive(Debug, Clone, PartialEq)]
@@ -9,7 +17,7 @@ pub enum FloatOrStr<'a> {
917
Str(&'a str),
1018
}
1119

12-
// TODO: check if additionl serde attributes are needed
20+
// TODO: check if additional serde attributes are needed
1321
#[derive(Debug, Clone, Deserialize)]
1422
#[serde(untagged)]
1523
pub enum FloatOrString {
@@ -26,7 +34,12 @@ impl FloatOrString {
2634
}
2735
}
2836

29-
fn get_float_from_config(
37+
/// Get a numeric float value from a scaling config by key.
38+
///
39+
/// Note: str variants in the config are not always floats — values like "default" or "linear"
40+
/// are also valid for non-numeric fields. This function should only be called for keys that
41+
/// are expected to hold numeric values.
42+
fn get_numeric_from_config(
3043
config: &HashMap<String, FloatOrString>,
3144
key: &str,
3245
) -> Result<f32, Exception> {
@@ -39,19 +52,21 @@ fn get_float_from_config(
3952
FloatOrStr::Float(f) => Ok(f),
4053
FloatOrStr::Str(s) => s
4154
.parse::<f32>()
42-
.map_err(|_| Exception::custom(format!(r#"key "{key}" is not a valid float"#))),
55+
.map_err(|_| Exception::custom(format!(r#"key "{key}" is not a valid number"#))),
4356
}
4457
}
4558

4659
/// Llama3-style RoPE with frequency scaling.
4760
///
4861
/// Applies piecewise frequency scaling based on wavelength cutoffs derived from
4962
/// `low_freq_factor`, `high_freq_factor`, `factor`, and `original_max_position_embeddings`.
50-
#[derive(Debug, Clone)]
63+
// TODO: support derive ModuleParameters for structs with non-param Array fields
64+
#[derive(Debug, Clone, ModuleParameters)]
5165
pub struct Llama3Rope {
5266
pub dimensions: i32,
5367
pub traditional: bool,
5468
pub scale: f32,
69+
/// Pre-computed scaled frequencies. Not a module parameter.
5570
pub freqs: Array,
5671
}
5772

@@ -67,12 +82,12 @@ impl Llama3Rope {
6782
) -> Result<Self, Exception> {
6883
let half_dims = dims / 2;
6984

70-
// Compute freqs as periods: base^(2i/dims), matching Python:
85+
// Compute freqs using MLX ops, matching Python:
7186
// freqs = base ** (mx.arange(0, dims, 2) / dims)
72-
let mut freqs = Vec::with_capacity(half_dims as usize);
73-
for i in 0..half_dims {
74-
freqs.push(base.powf(2.0 * i as f32 / dims as f32));
75-
}
87+
// which equals base^(2i/dims) for i in 0..half_dims
88+
let indices = arange::<_, f32>(None, half_dims, None)?;
89+
let exponents = indices.multiply(Array::from_f32(2.0 / dims as f32))?;
90+
let freqs = Array::from_f32(base).power(&exponents)?;
7691

7792
let old_context_len = original_max_position_embeddings as f32;
7893
let low_freq_wavelen = old_context_len / low_freq_factor;
@@ -85,68 +100,42 @@ impl Llama3Rope {
85100
// smooth_factors = (old_context_len / wavelens - low_freq_factor) / (high - low)
86101
// smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
87102
// freqs = where(is_medium, smooth_freqs, freqs)
88-
let mut scaled_freqs = Vec::with_capacity(half_dims as usize);
89-
for &freq in &freqs {
90-
let wavelen = 2.0 * std::f32::consts::PI * freq;
91-
// First pass: scale low frequencies (long wavelengths) by factor
92-
let freq = if wavelen > low_freq_wavelen {
93-
freq * factor
94-
} else {
95-
freq
96-
};
97-
// Second pass: apply smooth interpolation for medium frequencies
98-
let is_medium = wavelen > high_freq_wavelen && wavelen < low_freq_wavelen;
99-
if is_medium {
100-
let smooth_factor = (old_context_len / wavelen - low_freq_factor)
101-
/ (high_freq_factor - low_freq_factor);
102-
let smooth_freq = freq / ((1.0 - smooth_factor) / factor + smooth_factor);
103-
scaled_freqs.push(smooth_freq);
104-
} else {
105-
scaled_freqs.push(freq);
106-
}
107-
}
103+
let two_pi = Array::from_f32(2.0 * std::f32::consts::PI);
104+
let wavelens = freqs.multiply(&two_pi)?;
108105

109-
let freqs_array = Array::from_slice(&scaled_freqs, &[half_dims]);
106+
// First pass: scale low frequencies (long wavelengths) by factor
107+
let is_low = wavelens.gt(Array::from_f32(low_freq_wavelen))?;
108+
let freqs = which(&is_low, &freqs.multiply(Array::from_f32(factor))?, &freqs)?;
109+
110+
// Second pass: smooth interpolation for medium frequencies
111+
let is_medium = wavelens
112+
.gt(Array::from_f32(high_freq_wavelen))?
113+
.logical_and(&wavelens.lt(Array::from_f32(low_freq_wavelen))?)?;
114+
115+
let smooth_factors = wavelens
116+
.reciprocal()?
117+
.multiply(Array::from_f32(old_context_len))?
118+
.subtract(Array::from_f32(low_freq_factor))?
119+
.divide(Array::from_f32(high_freq_factor - low_freq_factor))?;
120+
121+
// smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
122+
let one_minus_smooth = Array::from_f32(1.0).subtract(&smooth_factors)?;
123+
let denom = one_minus_smooth
124+
.divide(Array::from_f32(factor))?
125+
.add(&smooth_factors)?;
126+
let smooth_freqs = freqs.divide(&denom)?;
127+
128+
let freqs = which(&is_medium, &smooth_freqs, &freqs)?;
110129

111130
Ok(Self {
112131
dimensions: dims,
113132
traditional,
114133
scale: 1.0,
115-
freqs: freqs_array,
134+
freqs,
116135
})
117136
}
118137
}
119138

120-
impl mlx_rs::module::ModuleParameters for Llama3Rope {
121-
fn num_parameters(&self) -> usize {
122-
0
123-
}
124-
125-
fn freeze_parameters(&mut self, _recursive: bool) {}
126-
127-
fn unfreeze_parameters(&mut self, _recursive: bool) {}
128-
129-
fn parameters(&self) -> mlx_rs::module::ModuleParamRef<'_> {
130-
mlx_rs::nested::NestedHashMap::new()
131-
}
132-
133-
fn parameters_mut(&mut self) -> mlx_rs::module::ModuleParamMut<'_> {
134-
mlx_rs::nested::NestedHashMap::new()
135-
}
136-
137-
fn trainable_parameters(&self) -> mlx_rs::module::ModuleParamRef<'_> {
138-
mlx_rs::nested::NestedHashMap::new()
139-
}
140-
141-
fn all_frozen(&self) -> Option<bool> {
142-
None
143-
}
144-
145-
fn any_frozen(&self) -> Option<bool> {
146-
None
147-
}
148-
}
149-
150139
impl<'a, Input> Module<Input> for Llama3Rope
151140
where
152141
Input: Into<nn::RopeInput<'a>>,
@@ -181,6 +170,7 @@ pub enum RopeVariant {
181170
Llama3(Llama3Rope),
182171
}
183172

173+
// TODO: support derive ModuleParameters for enum
184174
impl mlx_rs::module::ModuleParameters for RopeVariant {
185175
fn num_parameters(&self) -> usize {
186176
0
@@ -256,8 +246,7 @@ pub fn initialize_rope(
256246

257247
if rope_type == FloatOrStr::Str("default") || rope_type == FloatOrStr::Str("linear") {
258248
let scale = if rope_type == FloatOrStr::Str("linear") {
259-
let den = get_float_from_config(scaling_config.as_ref().unwrap(), "factor")?;
260-
249+
let den = get_numeric_from_config(scaling_config.as_ref().unwrap(), "factor")?;
261250
1.0 / den
262251
} else {
263252
1.0
@@ -275,11 +264,11 @@ pub fn initialize_rope(
275264
.as_ref()
276265
.ok_or_else(|| Exception::custom("scaling_config is required for llama3 RoPE"))?;
277266

278-
let factor = get_float_from_config(config, "factor")?;
279-
let low_freq_factor = get_float_from_config(config, "low_freq_factor")?;
280-
let high_freq_factor = get_float_from_config(config, "high_freq_factor")?;
267+
let factor = get_numeric_from_config(config, "factor")?;
268+
let low_freq_factor = get_numeric_from_config(config, "low_freq_factor")?;
269+
let high_freq_factor = get_numeric_from_config(config, "high_freq_factor")?;
281270
let original_max_position_embeddings =
282-
get_float_from_config(config, "original_max_position_embeddings")? as i32;
271+
get_numeric_from_config(config, "original_max_position_embeddings")? as i32;
283272

284273
let rope = Llama3Rope::new(
285274
dims,

0 commit comments

Comments
 (0)