Skip to content

Commit 189c7b1

Browse files
author
sangchengmeng
committed
fix triton_rotary_rope_emb
1 parent 3fe01a9 commit 189c7b1

File tree

1 file changed

+66
-28
lines changed

1 file changed

+66
-28
lines changed
Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import math
2-
import torch
31
import triton
42
import triton.language as tl
3+
import torch
54

65

76
@triton.jit
8-
def rotary_kernel(
7+
def rotary_kernel_tiled(
98
inp_ptr,
109
cos_ptr,
1110
sin_ptr,
@@ -17,38 +16,62 @@ def rotary_kernel(
1716
stride_cos_d,
1817
stride_sin_l,
1918
stride_sin_d,
20-
D: tl.constexpr,
21-
HALF_D: tl.constexpr,
19+
L,
20+
H,
21+
D,
22+
BLOCK_SEQ: tl.constexpr,
23+
BLOCK_HEAD: tl.constexpr,
2224
BLOCK_D: tl.constexpr,
2325
):
24-
pid_l = tl.program_id(0).to(tl.int64)
25-
pid_h = tl.program_id(1).to(tl.int64)
26-
pid_blk = tl.program_id(2).to(tl.int64)
26+
pid_head_blk = tl.program_id(0) # head tile
27+
pid_seq_blk = tl.program_id(1) # seq tile
28+
pid_d_blk = tl.program_id(2) # dim tile
29+
30+
offs_h = pid_head_blk * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
31+
offs_l = pid_seq_blk * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
32+
offs_d = pid_d_blk * BLOCK_D + tl.arange(0, BLOCK_D)
33+
34+
offs_h = offs_h.to(tl.int64)
35+
offs_l = offs_l.to(tl.int64)
36+
offs_d = offs_d.to(tl.int64)
37+
38+
mask_h = offs_h < H
39+
mask_l = offs_l < L
40+
mask_d = offs_d < D
41+
42+
HALF_D = D // 2
43+
44+
l_b = offs_l[:, None, None]
45+
h_b = offs_h[None, :, None]
46+
d_b = offs_d[None, None, :]
2747

28-
offs_d = tl.arange(0, BLOCK_D)
29-
d = pid_blk * BLOCK_D + offs_d
30-
mask = d < D
48+
mask = mask_l[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :]
3149

32-
base = pid_l * stride_l + pid_h * stride_h
50+
base = l_b * stride_l + h_b * stride_h + d_b * stride_d
51+
x = tl.load(inp_ptr + base, mask=mask, other=0.0)
3352

34-
in_ptr = inp_ptr + base + d * stride_d
53+
cos_base_2d = offs_l[:, None] * stride_cos_l + offs_d[None, :] * stride_cos_d
54+
sin_base_2d = offs_l[:, None] * stride_sin_l + offs_d[None, :] * stride_sin_d
3555

36-
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
37-
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
56+
mask_ld = mask_l[:, None] & mask_d[None, :]
3857

39-
x = tl.load(in_ptr, mask=mask)
40-
cos = tl.load(cos_ptr_, mask=mask)
41-
sin = tl.load(sin_ptr_, mask=mask)
58+
cos_2d = tl.load(cos_ptr + cos_base_2d, mask=mask_ld, other=0.0)
59+
sin_2d = tl.load(sin_ptr + sin_base_2d, mask=mask_ld, other=0.0)
4260

43-
partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
44-
partner_ptr = inp_ptr + base + partner_d * stride_d
45-
partner_val = tl.load(partner_ptr, mask=mask)
46-
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
61+
cos = cos_2d[:, None, :]
62+
sin = sin_2d[:, None, :]
63+
64+
partner_d = tl.where(offs_d < HALF_D, offs_d + HALF_D, offs_d - HALF_D)
65+
partner_d_b = partner_d[None, None, :]
66+
67+
partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
68+
partner_val = tl.load(inp_ptr + partner_base, mask=mask, other=0.0)
69+
70+
rotated = tl.where(d_b < HALF_D, -partner_val, partner_val)
4771

4872
y = x * cos + rotated * sin
4973

50-
out_ptr_ = out_ptr + base + d
51-
tl.store(out_ptr_, y, mask=mask)
74+
tl.store(out_ptr + base, y, mask=mask)
5275

5376

5477
def apply_rotary_pos_emb_triton(
@@ -66,12 +89,23 @@ def apply_rotary_pos_emb_triton(
6689
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()
6790

6891
L, H, D = x.shape
69-
HALF_D = D // 2
7092
y = torch.empty_like(x)
7193

72-
grid = (L, H, triton.cdiv(D, BLOCK_D))
94+
BLOCK_SEQ = 16
95+
BLOCK_HEAD = 4
96+
97+
if D >= 128:
98+
num_warps = 8
99+
else:
100+
num_warps = 4
101+
102+
grid = (
103+
triton.cdiv(H, BLOCK_HEAD),
104+
triton.cdiv(L, BLOCK_SEQ),
105+
triton.cdiv(D, BLOCK_D),
106+
)
73107

74-
rotary_kernel[grid](
108+
rotary_kernel_tiled[grid](
75109
inp_ptr=x,
76110
cos_ptr=cos,
77111
sin_ptr=sin,
@@ -83,9 +117,13 @@ def apply_rotary_pos_emb_triton(
83117
stride_cos_d=cos.stride(1),
84118
stride_sin_l=sin.stride(0),
85119
stride_sin_d=sin.stride(1),
120+
L=L,
121+
H=H,
86122
D=D,
87-
HALF_D=HALF_D,
123+
BLOCK_SEQ=BLOCK_SEQ,
124+
BLOCK_HEAD=BLOCK_HEAD,
88125
BLOCK_D=BLOCK_D,
126+
num_warps=num_warps,
89127
)
90128

91129
return y.to(orig_dtype)

0 commit comments

Comments
 (0)