1- import math
2- import torch
31import triton
42import triton .language as tl
3+ import torch
54
65
76@triton .jit
8- def rotary_kernel (
7+ def rotary_kernel_tiled (
98 inp_ptr ,
109 cos_ptr ,
1110 sin_ptr ,
@@ -17,38 +16,62 @@ def rotary_kernel(
1716 stride_cos_d ,
1817 stride_sin_l ,
1918 stride_sin_d ,
20- D : tl .constexpr ,
21- HALF_D : tl .constexpr ,
19+ L ,
20+ H ,
21+ D ,
22+ BLOCK_SEQ : tl .constexpr ,
23+ BLOCK_HEAD : tl .constexpr ,
2224 BLOCK_D : tl .constexpr ,
2325):
24- pid_l = tl .program_id (0 ).to (tl .int64 )
25- pid_h = tl .program_id (1 ).to (tl .int64 )
26- pid_blk = tl .program_id (2 ).to (tl .int64 )
26+ pid_head_blk = tl .program_id (0 ) # head tile
27+ pid_seq_blk = tl .program_id (1 ) # seq tile
28+ pid_d_blk = tl .program_id (2 ) # dim tile
29+
30+ offs_h = pid_head_blk * BLOCK_HEAD + tl .arange (0 , BLOCK_HEAD )
31+ offs_l = pid_seq_blk * BLOCK_SEQ + tl .arange (0 , BLOCK_SEQ )
32+ offs_d = pid_d_blk * BLOCK_D + tl .arange (0 , BLOCK_D )
33+
34+ offs_h = offs_h .to (tl .int64 )
35+ offs_l = offs_l .to (tl .int64 )
36+ offs_d = offs_d .to (tl .int64 )
37+
38+ mask_h = offs_h < H
39+ mask_l = offs_l < L
40+ mask_d = offs_d < D
41+
42+ HALF_D = D // 2
43+
44+ l_b = offs_l [:, None , None ]
45+ h_b = offs_h [None , :, None ]
46+ d_b = offs_d [None , None , :]
2747
28- offs_d = tl .arange (0 , BLOCK_D )
29- d = pid_blk * BLOCK_D + offs_d
30- mask = d < D
48+ mask = mask_l [:, None , None ] & mask_h [None , :, None ] & mask_d [None , None , :]
3149
32- base = pid_l * stride_l + pid_h * stride_h
50+ base = l_b * stride_l + h_b * stride_h + d_b * stride_d
51+ x = tl .load (inp_ptr + base , mask = mask , other = 0.0 )
3352
34- in_ptr = inp_ptr + base + d * stride_d
53+ cos_base_2d = offs_l [:, None ] * stride_cos_l + offs_d [None , :] * stride_cos_d
54+ sin_base_2d = offs_l [:, None ] * stride_sin_l + offs_d [None , :] * stride_sin_d
3555
36- cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
37- sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
56+ mask_ld = mask_l [:, None ] & mask_d [None , :]
3857
39- x = tl .load (in_ptr , mask = mask )
40- cos = tl .load (cos_ptr_ , mask = mask )
41- sin = tl .load (sin_ptr_ , mask = mask )
58+ cos_2d = tl .load (cos_ptr + cos_base_2d , mask = mask_ld , other = 0.0 )
59+ sin_2d = tl .load (sin_ptr + sin_base_2d , mask = mask_ld , other = 0.0 )
4260
43- partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
44- partner_ptr = inp_ptr + base + partner_d * stride_d
45- partner_val = tl .load (partner_ptr , mask = mask )
46- rotated = tl .where (d < HALF_D , - partner_val , partner_val )
61+ cos = cos_2d [:, None , :]
62+ sin = sin_2d [:, None , :]
63+
64+ partner_d = tl .where (offs_d < HALF_D , offs_d + HALF_D , offs_d - HALF_D )
65+ partner_d_b = partner_d [None , None , :]
66+
67+ partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
68+ partner_val = tl .load (inp_ptr + partner_base , mask = mask , other = 0.0 )
69+
70+ rotated = tl .where (d_b < HALF_D , - partner_val , partner_val )
4771
4872 y = x * cos + rotated * sin
4973
50- out_ptr_ = out_ptr + base + d
51- tl .store (out_ptr_ , y , mask = mask )
74+ tl .store (out_ptr + base , y , mask = mask )
5275
5376
5477def apply_rotary_pos_emb_triton (
@@ -66,12 +89,23 @@ def apply_rotary_pos_emb_triton(
6689 sin = sin .repeat (1 , 2 ).view (sin .size (0 ), - 1 ).contiguous ().float ()
6790
6891 L , H , D = x .shape
69- HALF_D = D // 2
7092 y = torch .empty_like (x )
7193
72- grid = (L , H , triton .cdiv (D , BLOCK_D ))
94+ BLOCK_SEQ = 16
95+ BLOCK_HEAD = 4
96+
97+ if D >= 128 :
98+ num_warps = 8
99+ else :
100+ num_warps = 4
101+
102+ grid = (
103+ triton .cdiv (H , BLOCK_HEAD ),
104+ triton .cdiv (L , BLOCK_SEQ ),
105+ triton .cdiv (D , BLOCK_D ),
106+ )
73107
74- rotary_kernel [grid ](
108+ rotary_kernel_tiled [grid ](
75109 inp_ptr = x ,
76110 cos_ptr = cos ,
77111 sin_ptr = sin ,
@@ -83,9 +117,13 @@ def apply_rotary_pos_emb_triton(
83117 stride_cos_d = cos .stride (1 ),
84118 stride_sin_l = sin .stride (0 ),
85119 stride_sin_d = sin .stride (1 ),
120+ L = L ,
121+ H = H ,
86122 D = D ,
87- HALF_D = HALF_D ,
123+ BLOCK_SEQ = BLOCK_SEQ ,
124+ BLOCK_HEAD = BLOCK_HEAD ,
88125 BLOCK_D = BLOCK_D ,
126+ num_warps = num_warps ,
89127 )
90128
91129 return y .to (orig_dtype )
0 commit comments