@@ -159,10 +159,13 @@ struct Params {
159159 freq_scale : f32 ,
160160 ext_factor : f32 ,
161161 corr_dim0 : f32 ,
162- corr_dim1 : f32
162+ corr_dim1 : f32 ,
163+ sections0 : u32 ,
164+ sections1 : u32 ,
165+ sections2 : u32 ,
166+ sections3 : u32
163167};
164168
165-
166169@group (0 ) @binding (0 )
167170var <storage , read_write > src0 : array <{{TYPE }}>;
168171
@@ -189,19 +192,21 @@ fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
189192 return vec2 <f32 >(cos (theta ) * mscale , sin (theta ) * mscale );
190193}
191194
192- fn pair_base (i0 : u32 ) -> u32 {
193- switch ( params . mode ) {
194- case 0 { return i0 ; } // norm
195- case 2 { return i0 / 2 ; } // neox
196- default { return 1 ; }
195+ fn pair_base (i0 : u32 , div_2 : bool ) -> u32 {
196+ if ( div_2 ) {
197+ return i0 / 2 ;
198+ } else {
199+ return i0 ;
197200 }
198201}
199202
200- fn pair_offset () -> u32 {
201- switch (params . mode ) {
202- case 0 { return 1 ; } // norm
203- case 2 { return params . n_dims / 2 ; } // neox
204- default { return 1 ; }
203+ fn pair_offset (is_neox : bool , is_mrope : bool , is_vision : bool ) -> u32 {
204+ if (is_vision ) {
205+ return params . n_dims ;
206+ } else if (is_neox || is_mrope ) {
207+ return params . n_dims / 2 ;
208+ } else {
209+ return 1 ;
205210 }
206211}
207212
@@ -213,6 +218,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
213218 return ;
214219 }
215220
221+ let is_neox = bool (params . mode & 2 );
222+ let is_mrope = bool (params . mode & 8 );
223+ let is_vision = params . mode == 24 ;
224+
216225 var i = gid . x * 2 ; // start index for this thread
217226 let i3 = i / (params . ne2 * params . ne1 * params . ne0 );
218227 i = i % (params . ne2 * params . ne1 * params . ne0 );
@@ -224,20 +233,49 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
224233 let i_src_row = params . offset_src0 + i3 * params . stride_src03 + i2 * params . stride_src02 + i1 * params . stride_src01 ;
225234 let i_dst_row = params . offset_dst + i3 * params . stride_dst3 + i2 * params . stride_dst2 + i1 * params . stride_dst1 ;
226235
227- if (i0 >= params . n_dims ) {
228- rotate (i_dst_row + i0 , i_dst_row + i0 + 1 , f32 (src0 [i_src_row + i0 ]), f32 (src0 [i_src_row + i0 + 1 ]));
236+ if (i0 >= params . n_dims && ! is_vision ) {
237+ let i_src = i_src_row + i0 ;
238+ let i_dst = i_dst_row + i0 ;
239+ rotate (i_dst , i_dst + 1 , f32 (src0 [i_src ]), f32 (src0 [i_src + 1 ]));
229240 return ;
230241 }
231242
232- let theta_base = f32 (src1 [params . offset_src1 + i2 ]) * pow (params . theta_scale , f32 (i0 )/ 2 .0f );
243+ var theta_base_mult : u32 = 0 ;
244+ var theta_scale_pwr : u32 = i0 / 2 ;
245+ if (is_mrope ) {
246+ let sect_dims = params . sections0 + params . sections1 + params . sections2 + params . sections3 ;
247+ let sec_w = params . sections1 + params . sections0 ;
248+ let sec_e = params . sections2 + sec_w ;
249+ let sector = (i0 / 2 ) % sect_dims ;
250+ if (sector >= params . sections0 && sector < sec_w ) {
251+ theta_base_mult = 1 ;
252+ if (is_vision ) {
253+ theta_scale_pwr = sector - params . sections0 ;
254+ }
255+ } else if (sector >= sec_w && sector < sec_e ) {
256+ theta_base_mult = 2 ;
257+ if (is_vision ) {
258+ theta_scale_pwr = sector - sec_w ;
259+ }
260+ } else if (sector >= sec_e ) {
261+ if (is_vision ) {
262+ theta_scale_pwr = sector - sec_e ;
263+ theta_scale_pwr = (i0 / 2 ) % sec_e ;
264+ }
265+ theta_base_mult = 3 ;
266+ } else if (is_vision ) {
267+ theta_scale_pwr = sector ;
268+ }
269+ }
270+ let theta_base = f32 (src1 [params . offset_src1 + i2 + params . ne2 * theta_base_mult ]) * pow (params . theta_scale , f32 (theta_scale_pwr ));
233271 let thetas = rope_yarn (theta_base / freq_factor (i0 ), i0 );
234272
235- let i_src = i_src_row + pair_base (i0 );
236- let i_dst = i_dst_row + pair_base (i0 );
273+ let i_src = i_src_row + pair_base (i0 , is_neox || is_mrope || is_vision );
274+ let i_dst = i_dst_row + pair_base (i0 , is_neox || is_mrope || is_vision );
237275
238276 let x0 = f32 (src0 [i_src ]);
239- let x1 = f32 (src0 [i_src + pair_offset ()]);
240- rotate (i_dst , i_dst + pair_offset (), x0 * thetas . x - x1 * thetas . y , x0 * thetas . y + x1 * thetas . x );
277+ let x1 = f32 (src0 [i_src + pair_offset (is_neox , is_mrope , is_vision )]);
278+ rotate (i_dst , i_dst + pair_offset (is_neox , is_mrope , is_vision ), x0 * thetas . x - x1 * thetas . y , x0 * thetas . y + x1 * thetas . x );
241279}
242280
243281#end (SHADER )
0 commit comments