|
| 1 | +import time |
| 2 | +import torch |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | + |
| 6 | + |
| 7 | +@triton.jit |
| 8 | +def mrope_kernel_combined( |
| 9 | + Q_ptr, |
| 10 | + K_ptr, |
| 11 | + COS_ptr, |
| 12 | + SIN_ptr, |
| 13 | + AXIS_MAP_ptr, |
| 14 | + Q_out_ptr, |
| 15 | + K_out_ptr, |
| 16 | + B: tl.int32, |
| 17 | + H_q: tl.int32, |
| 18 | + H_k: tl.int32, |
| 19 | + L: tl.int32, |
| 20 | + D: tl.int32, |
| 21 | + HALF: tl.int32, |
| 22 | + BLOCK_D: tl.constexpr, |
| 23 | +): |
| 24 | + total_h = H_q + H_k |
| 25 | + |
| 26 | + pid_bh = tl.program_id(0) |
| 27 | + pid_l = tl.program_id(1) |
| 28 | + |
| 29 | + b = pid_bh // total_h |
| 30 | + head_local = pid_bh - b * total_h |
| 31 | + |
| 32 | + # decide whether this head comes from q or k |
| 33 | + is_q = head_local < H_q |
| 34 | + head_q = head_local |
| 35 | + head_k = head_local - H_q |
| 36 | + |
| 37 | + base_ptr = tl.where(is_q, Q_ptr, K_ptr) |
| 38 | + out_ptr = tl.where(is_q, Q_out_ptr, K_out_ptr) |
| 39 | + h_sub = tl.where(is_q, head_q, head_k) |
| 40 | + H_sub = tl.where(is_q, H_q, H_k) |
| 41 | + |
| 42 | + # base offset for (b, h_sub, pid_l) |
| 43 | + base = ((b * H_sub + h_sub) * L + pid_l) * D |
| 44 | + |
| 45 | + offs = tl.arange(0, BLOCK_D) |
| 46 | + idx = base + offs |
| 47 | + mask = offs < D |
| 48 | + |
| 49 | + vals = tl.load(base_ptr + idx, mask=mask, other=0.0) |
| 50 | + axis_id = tl.load(AXIS_MAP_ptr + offs, mask=mask, other=0) |
| 51 | + axis_id_b = b * 3 + axis_id |
| 52 | + |
| 53 | + seq_off = pid_l * D |
| 54 | + cos_idx = axis_id_b * (L * D) + seq_off + offs |
| 55 | + c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0) |
| 56 | + s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0) |
| 57 | + |
| 58 | + # rotate_half |
| 59 | + rot_idx = tl.where(offs < HALF, idx + HALF, idx - HALF) |
| 60 | + rot_vals = tl.load(base_ptr + rot_idx, mask=mask, other=0.0) |
| 61 | + sign = tl.where(offs < HALF, -1.0, 1.0) |
| 62 | + rot_vals *= sign |
| 63 | + |
| 64 | + out_vals = vals * c + rot_vals * s |
| 65 | + tl.store(out_ptr + idx, out_vals, mask=mask) |
| 66 | + |
| 67 | + |
| 68 | +def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mrope_section): |
| 69 | + B, H_q, L, D = q.shape |
| 70 | + H_k = k.shape[1] |
| 71 | + |
| 72 | + # build axis_map 0/1/2 label per feature dim |
| 73 | + axis_map = [] |
| 74 | + for i, n in enumerate(mrope_section * 2): |
| 75 | + axis_map += [i % 3] * n |
| 76 | + axis_map = torch.tensor(axis_map, dtype=torch.int32, device=q.device) |
| 77 | + |
| 78 | + cos_flat = cos.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1) |
| 79 | + sin_flat = sin.transpose(0, 1).expand(B, 3, L, D).contiguous().reshape(-1) |
| 80 | + |
| 81 | + q_out = torch.empty_like(q) |
| 82 | + k_out = torch.empty_like(k) |
| 83 | + |
| 84 | + grid = (B * (H_q + H_k), L) |
| 85 | + mrope_kernel_combined[grid]( |
| 86 | + q, |
| 87 | + k, |
| 88 | + cos_flat, |
| 89 | + sin_flat, |
| 90 | + axis_map, |
| 91 | + q_out, |
| 92 | + k_out, |
| 93 | + B, |
| 94 | + H_q, |
| 95 | + H_k, |
| 96 | + L, |
| 97 | + D, |
| 98 | + D // 2, |
| 99 | + BLOCK_D=128, |
| 100 | + ) |
| 101 | + return q_out, k_out |
| 102 | + |
| 103 | + |
| 104 | +# ---------------- test ---------------- # |
| 105 | +def test(): |
| 106 | + |
| 107 | + # torch实现 |
| 108 | + def rotate_half(x: torch.Tensor): |
| 109 | + x1 = x[..., : x.shape[-1] // 2] |
| 110 | + x2 = x[..., x.shape[-1] // 2 :] |
| 111 | + return torch.cat((-x2, x1), dim=-1) |
| 112 | + |
| 113 | + def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): |
| 114 | + chunks = mrope_section * 2 |
| 115 | + cos_embed = torch.cat( |
| 116 | + [m[i % 3] for i, m in enumerate(cos.split(chunks, dim=-1))], |
| 117 | + dim=-1, |
| 118 | + ).unsqueeze(unsqueeze_dim) |
| 119 | + sin_embed = torch.cat( |
| 120 | + [m[i % 3] for i, m in enumerate(sin.split(chunks, dim=-1))], |
| 121 | + dim=-1, |
| 122 | + ).unsqueeze(unsqueeze_dim) |
| 123 | + |
| 124 | + q_out = q * cos_embed + rotate_half(q) * sin_embed |
| 125 | + k_out = k * cos_embed + rotate_half(k) * sin_embed |
| 126 | + return q_out, k_out |
| 127 | + |
| 128 | + B, H_q, H_k, L, D = 1, 28, 4, 16384, 128 |
| 129 | + mrope_section = [16, 24, 24] |
| 130 | + torch.manual_seed(0) |
| 131 | + device = "cuda" |
| 132 | + |
| 133 | + q = torch.rand(B, H_q, L, D, dtype=torch.float32, device=device) |
| 134 | + k = torch.rand(B, H_k, L, D, dtype=torch.float32, device=device) |
| 135 | + cos = torch.rand(3, 1, L, D, dtype=torch.float32, device=device) |
| 136 | + sin = torch.rand(3, 1, L, D, dtype=torch.float32, device=device) |
| 137 | + |
| 138 | + # 精度对比 |
| 139 | + ref_q, ref_k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1) |
| 140 | + |
| 141 | + torch.cuda.synchronize() |
| 142 | + out_q, out_k = mrope_triton(q, k, cos, sin, mrope_section) |
| 143 | + torch.cuda.synchronize() |
| 144 | + |
| 145 | + err_q = (out_q - ref_q).abs().max().item() |
| 146 | + err_k = (out_k - ref_k).abs().max().item() |
| 147 | + print(f"abs‑max error q:{err_q:.6f}, k:{err_k:.6f}") |
| 148 | + |
| 149 | + assert err_q < 1e-2 and err_k < 1e-2 |
| 150 | + |
| 151 | + # 速度对比 |
| 152 | + n_iter = 100 |
| 153 | + e0 = torch.cuda.Event(enable_timing=True) |
| 154 | + e1 = torch.cuda.Event(enable_timing=True) |
| 155 | + |
| 156 | + e0.record() |
| 157 | + for _ in range(n_iter): |
| 158 | + _ = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1) |
| 159 | + e1.record() |
| 160 | + torch.cuda.synchronize() |
| 161 | + t_ref = e0.elapsed_time(e1) / n_iter |
| 162 | + |
| 163 | + e0.record() |
| 164 | + for _ in range(n_iter): |
| 165 | + _ = mrope_triton(q, k, cos, sin, mrope_section) |
| 166 | + e1.record() |
| 167 | + torch.cuda.synchronize() |
| 168 | + t_tri = e0.elapsed_time(e1) / n_iter |
| 169 | + |
| 170 | + print(f"torch {t_ref:.2f} ms/iter") |
| 171 | + print(f"triton {t_tri:.2f} ms/iter") |
| 172 | + |
0 commit comments