@@ -19,6 +19,8 @@ def mrope_kernel(
1919 L : tl .int32 ,
2020 D : tl .int32 ,
2121 HALF : tl .constexpr ,
22+ s_tok : tl .int32 ,
23+ s_ax : tl .int32 ,
2224 q_sb : tl .int32 ,
2325 q_sh : tl .int32 ,
2426 q_sl : tl .int32 ,
@@ -75,7 +77,7 @@ def mrope_kernel(
7577 rot_vals = tl .where (offs < HALF , - rot_vals , rot_vals )
7678
7779 axis_id = tl .load (AXIS_ptr + offs , mask = mask , other = 0 ) # 0,1,2
78- cos_idx = axis_id * ( L * D ) + pid_l * D + offs
80+ cos_idx = pid_l * s_tok + axis_id * s_ax + offs
7981 c = tl .load (COS_ptr + cos_idx , mask = mask , other = 0.0 )
8082 s = tl .load (SIN_ptr + cos_idx , mask = mask , other = 0.0 )
8183
@@ -99,16 +101,19 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
99101 qo_sb , qo_sh , qo_sl , qo_sd = map (int , q_out .stride ())
100102 ko_sb , ko_sh , ko_sl , ko_sd = map (int , k_out .stride ())
101103
102- cos_flat = cos .transpose (0 , 1 ).contiguous ().reshape (- 1 )
103- sin_flat = sin .transpose (0 , 1 ).contiguous ().reshape (- 1 )
104+ token_dim = next (i for i , s in enumerate (cos .shape ) if s == L )
105+ axis_dim = next (i for i , s in enumerate (cos .shape ) if s == 3 )
106+
107+ s_token = int (cos .stride (token_dim ))
108+ s_axis = int (cos .stride (axis_dim ))
104109
105110 grid = (B * (H_q + H_k ), L )
106111
107112 mrope_kernel [grid ](
108113 q ,
109114 k ,
110- cos_flat ,
111- sin_flat ,
115+ cos ,
116+ sin ,
112117 axis_map ,
113118 q_out ,
114119 k_out ,
@@ -118,6 +123,8 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
118123 L ,
119124 D ,
120125 HALF ,
126+ s_token ,
127+ s_axis ,
121128 q_sb ,
122129 q_sh ,
123130 q_sl ,
0 commit comments