11use std:: collections:: HashMap ;
22
3- use mlx_rs:: { builder:: Builder , error:: Exception , module:: Module , nn, Array } ;
3+ use mlx_macros:: ModuleParameters ;
4+ use mlx_rs:: {
5+ builder:: Builder ,
6+ error:: Exception ,
7+ module:: Module ,
8+ nn,
9+ ops:: { arange, which} ,
10+ Array ,
11+ } ;
412use serde:: Deserialize ;
513
614#[ derive( Debug , Clone , PartialEq ) ]
@@ -9,7 +17,7 @@ pub enum FloatOrStr<'a> {
917 Str ( & ' a str ) ,
1018}
1119
12- // TODO: check if additionl serde attributes are needed
20+ // TODO: check if additional serde attributes are needed
1321#[ derive( Debug , Clone , Deserialize ) ]
1422#[ serde( untagged) ]
1523pub enum FloatOrString {
@@ -26,7 +34,12 @@ impl FloatOrString {
2634 }
2735}
2836
29- fn get_float_from_config (
37+ /// Get a numeric (`f32`) value from a scaling config by key.
38+ ///
39+ /// Note: string values in the config are not always numeric — values like "default" or "linear"
40+ /// are valid for non-numeric fields. Call this only for keys expected to hold numbers; a string
41+ /// value that does not parse as `f32` yields an `Exception`.
42+ fn get_numeric_from_config (
3043 config : & HashMap < String , FloatOrString > ,
3144 key : & str ,
3245) -> Result < f32 , Exception > {
@@ -39,19 +52,21 @@ fn get_float_from_config(
3952 FloatOrStr :: Float ( f) => Ok ( f) ,
4053 FloatOrStr :: Str ( s) => s
4154 . parse :: < f32 > ( )
42- . map_err ( |_| Exception :: custom ( format ! ( r#"key "{key}" is not a valid float "# ) ) ) ,
55+ . map_err ( |_| Exception :: custom ( format ! ( r#"key "{key}" is not a valid number "# ) ) ) ,
4356 }
4457}
4558
4659/// Llama3-style RoPE with frequency scaling.
4760///
4861/// Applies piecewise frequency scaling by `factor`, using wavelength cutoffs derived from
4962/// `low_freq_factor`, `high_freq_factor`, and `original_max_position_embeddings`.
50- #[ derive( Debug , Clone ) ]
63+ // TODO: support derive ModuleParameters for structs with non-param Array fields
64+ #[ derive( Debug , Clone , ModuleParameters ) ]
5165pub struct Llama3Rope {
5266 pub dimensions : i32 ,
5367 pub traditional : bool ,
5468 pub scale : f32 ,
69+ /// Pre-computed scaled frequencies, built once in `new`. Not a module parameter.
5570 pub freqs : Array ,
5671}
5772
@@ -67,12 +82,12 @@ impl Llama3Rope {
6782 ) -> Result < Self , Exception > {
6883 let half_dims = dims / 2 ;
6984
70- // Compute freqs as periods: base^(2i/dims) , matching Python:
85+ // Compute freqs using MLX ops , matching Python:
7186 // freqs = base ** (mx.arange(0, dims, 2) / dims)
72- let mut freqs = Vec :: with_capacity ( half_dims as usize ) ;
73- for i in 0 .. half_dims {
74- freqs . push ( base . powf ( 2.0 * i as f32 / dims as f32 ) ) ;
75- }
87+ // which equals base^(2i/dims) for i in 0..half_dims (when dims is even)
88+ let indices = arange :: < _ , f32 > ( None , half_dims, None ) ? ;
89+ let exponents = indices . multiply ( Array :: from_f32 ( 2.0 / dims as f32 ) ) ? ;
90+ let freqs = Array :: from_f32 ( base ) . power ( & exponents ) ? ;
7691
7792 let old_context_len = original_max_position_embeddings as f32 ;
7893 let low_freq_wavelen = old_context_len / low_freq_factor;
@@ -85,68 +100,42 @@ impl Llama3Rope {
85100 // smooth_factors = (old_context_len / wavelens - low_freq_factor) / (high - low)
86101 // smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
87102 // freqs = where(is_medium, smooth_freqs, freqs)
88- let mut scaled_freqs = Vec :: with_capacity ( half_dims as usize ) ;
89- for & freq in & freqs {
90- let wavelen = 2.0 * std:: f32:: consts:: PI * freq;
91- // First pass: scale low frequencies (long wavelengths) by factor
92- let freq = if wavelen > low_freq_wavelen {
93- freq * factor
94- } else {
95- freq
96- } ;
97- // Second pass: apply smooth interpolation for medium frequencies
98- let is_medium = wavelen > high_freq_wavelen && wavelen < low_freq_wavelen;
99- if is_medium {
100- let smooth_factor = ( old_context_len / wavelen - low_freq_factor)
101- / ( high_freq_factor - low_freq_factor) ;
102- let smooth_freq = freq / ( ( 1.0 - smooth_factor) / factor + smooth_factor) ;
103- scaled_freqs. push ( smooth_freq) ;
104- } else {
105- scaled_freqs. push ( freq) ;
106- }
107- }
103+ let two_pi = Array :: from_f32 ( 2.0 * std:: f32:: consts:: PI ) ;
104+ let wavelens = freqs. multiply ( & two_pi) ?;
108105
109- let freqs_array = Array :: from_slice ( & scaled_freqs, & [ half_dims] ) ;
106+ // First pass: where the wavelength exceeds low_freq_wavelen, multiply freqs by factor
107+ let is_low = wavelens. gt ( Array :: from_f32 ( low_freq_wavelen) ) ?;
108+ let freqs = which ( & is_low, & freqs. multiply ( Array :: from_f32 ( factor) ) ?, & freqs) ?;
109+
110+ // Second pass: smooth interpolation for medium frequencies
111+ let is_medium = wavelens
112+ . gt ( Array :: from_f32 ( high_freq_wavelen) ) ?
113+ . logical_and ( & wavelens. lt ( Array :: from_f32 ( low_freq_wavelen) ) ?) ?;
114+
115+ let smooth_factors = wavelens
116+ . reciprocal ( ) ?
117+ . multiply ( Array :: from_f32 ( old_context_len) ) ?
118+ . subtract ( Array :: from_f32 ( low_freq_factor) ) ?
119+ . divide ( Array :: from_f32 ( high_freq_factor - low_freq_factor) ) ?;
120+
121+ // smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
122+ let one_minus_smooth = Array :: from_f32 ( 1.0 ) . subtract ( & smooth_factors) ?;
123+ let denom = one_minus_smooth
124+ . divide ( Array :: from_f32 ( factor) ) ?
125+ . add ( & smooth_factors) ?;
126+ let smooth_freqs = freqs. divide ( & denom) ?;
127+
128+ let freqs = which ( & is_medium, & smooth_freqs, & freqs) ?;
110129
111130 Ok ( Self {
112131 dimensions : dims,
113132 traditional,
114133 scale : 1.0 ,
115- freqs : freqs_array ,
134+ freqs,
116135 } )
117136 }
118137}
119138
120- impl mlx_rs:: module:: ModuleParameters for Llama3Rope {
121- fn num_parameters ( & self ) -> usize {
122- 0
123- }
124-
125- fn freeze_parameters ( & mut self , _recursive : bool ) { }
126-
127- fn unfreeze_parameters ( & mut self , _recursive : bool ) { }
128-
129- fn parameters ( & self ) -> mlx_rs:: module:: ModuleParamRef < ' _ > {
130- mlx_rs:: nested:: NestedHashMap :: new ( )
131- }
132-
133- fn parameters_mut ( & mut self ) -> mlx_rs:: module:: ModuleParamMut < ' _ > {
134- mlx_rs:: nested:: NestedHashMap :: new ( )
135- }
136-
137- fn trainable_parameters ( & self ) -> mlx_rs:: module:: ModuleParamRef < ' _ > {
138- mlx_rs:: nested:: NestedHashMap :: new ( )
139- }
140-
141- fn all_frozen ( & self ) -> Option < bool > {
142- None
143- }
144-
145- fn any_frozen ( & self ) -> Option < bool > {
146- None
147- }
148- }
149-
150139impl < ' a , Input > Module < Input > for Llama3Rope
151140where
152141 Input : Into < nn:: RopeInput < ' a > > ,
@@ -181,6 +170,7 @@ pub enum RopeVariant {
181170 Llama3 ( Llama3Rope ) ,
182171}
183172
173+ // TODO: support derive ModuleParameters for enums
184174impl mlx_rs:: module:: ModuleParameters for RopeVariant {
185175 fn num_parameters ( & self ) -> usize {
186176 0
@@ -256,8 +246,7 @@ pub fn initialize_rope(
256246
257247 if rope_type == FloatOrStr :: Str ( "default" ) || rope_type == FloatOrStr :: Str ( "linear" ) {
258248 let scale = if rope_type == FloatOrStr :: Str ( "linear" ) {
259- let den = get_float_from_config ( scaling_config. as_ref ( ) . unwrap ( ) , "factor" ) ?;
260-
249+ let den = get_numeric_from_config ( scaling_config. as_ref ( ) . unwrap ( ) , "factor" ) ?;
261250 1.0 / den
262251 } else {
263252 1.0
@@ -275,11 +264,11 @@ pub fn initialize_rope(
275264 . as_ref ( )
276265 . ok_or_else ( || Exception :: custom ( "scaling_config is required for llama3 RoPE" ) ) ?;
277266
278- let factor = get_float_from_config ( config, "factor" ) ?;
279- let low_freq_factor = get_float_from_config ( config, "low_freq_factor" ) ?;
280- let high_freq_factor = get_float_from_config ( config, "high_freq_factor" ) ?;
267+ let factor = get_numeric_from_config ( config, "factor" ) ?;
268+ let low_freq_factor = get_numeric_from_config ( config, "low_freq_factor" ) ?;
269+ let high_freq_factor = get_numeric_from_config ( config, "high_freq_factor" ) ?;
281270 let original_max_position_embeddings =
282- get_float_from_config ( config, "original_max_position_embeddings" ) ? as i32 ;
271+ get_numeric_from_config ( config, "original_max_position_embeddings" ) ? as i32 ;
283272
284273 let rope = Llama3Rope :: new (
285274 dims,
0 commit comments