@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
221221
222222 let is_neox = bool (params . mode & 2 );
223223 let is_mrope = bool (params . mode & 8 );
224+ let is_imrope = params . mode == 40 ;
224225 let is_vision = params . mode == 24 ;
225226
226227 var i = gid . x * 2 ; // start index for this thread
@@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
248249 let sec_w = params . sections1 + params . sections0 ;
249250 let sec_e = params . sections2 + sec_w ;
250251 let sector = (i0 / 2 ) % sect_dims ;
251- if (sector >= params . sections0 && sector < sec_w ) {
252- theta_base_mult = 1 ;
253- if (is_vision ) {
254- theta_scale_pwr = sector - params . sections0 ;
255- }
256- } else if (sector >= sec_w && sector < sec_e ) {
257- theta_base_mult = 2 ;
258- if (is_vision ) {
259- theta_scale_pwr = sector - sec_w ;
260- }
261- } else if (sector >= sec_e ) {
262- if (is_vision ) {
263- theta_scale_pwr = sector - sec_e ;
264- theta_scale_pwr = (i0 / 2 ) % sec_e ;
265- }
266- theta_base_mult = 3 ;
267- } else if (is_vision ) {
268- theta_scale_pwr = sector ;
252+ if (is_imrope ) {
253+ if (sector % 3 == 1 && sector < 3 * params . sections1 ) {
254+ theta_base_mult = 1 ;
255+ } else if (sector % 3 == 2 && sector < 3 * params . sections2 ) {
256+ theta_base_mult = 2 ;
257+ } else if (sector % 3 == 0 && sector < 3 * params . sections0 ) {
258+ theta_base_mult = 0 ;
259+ } else {
260+ theta_base_mult = 3 ;
261+ }
262+ } else {
263+ if (sector >= params . sections0 && sector < sec_w ) {
264+ theta_base_mult = 1 ;
265+ if (is_vision ) {
266+ theta_scale_pwr = sector - params . sections0 ;
267+ }
268+ } else if (sector >= sec_w && sector < sec_e ) {
269+ theta_base_mult = 2 ;
270+ if (is_vision ) {
271+ theta_scale_pwr = sector - sec_w ;
272+ }
273+ } else if (sector >= sec_e ) {
274+ if (is_vision ) {
275+ theta_scale_pwr = sector - sec_e ;
276+ theta_scale_pwr = (i0 / 2 ) % sec_e ;
277+ }
278+ theta_base_mult = 3 ;
279+ } else if (is_vision ) {
280+ theta_scale_pwr = sector ;
281+ }
269282 }
270283 }
271284 let theta_base = f32 (src1 [params . offset_src1 + i2 + params . ne2 * theta_base_mult ]) * pow (params . theta_scale , f32 (theta_scale_pwr ));
0 commit comments