|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy |
| 6 | + |
| 7 | + |
| 8 | +def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int): |
| 9 | + """ |
| 10 | + Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors. |
| 11 | +
|
| 12 | + Supports: |
| 13 | + - complex freqs: (..., head_dim_half) complex -> real/imag |
| 14 | + - packed freqs: (..., 2*head_dim_half) real -> split into real/imag |
| 15 | + """ |
| 16 | + if freqs_cis.is_complex(): |
| 17 | + freqs_real = freqs_cis.real |
| 18 | + freqs_imag = freqs_cis.imag |
| 19 | + else: |
| 20 | + if freqs_cis.shape[-1] == 2 * head_dim_half: |
| 21 | + freqs_real = freqs_cis[..., :head_dim_half] |
| 22 | + freqs_imag = freqs_cis[..., head_dim_half:] |
| 23 | + else: |
| 24 | + raise ValueError( |
| 25 | + f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, " |
| 26 | + f"expected last dim = {2 * head_dim_half}" |
| 27 | + ) |
| 28 | + |
| 29 | + if freqs_real.shape[-1] != head_dim_half: |
| 30 | + raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})") |
| 31 | + |
| 32 | + # Flatten leading dims -> (N, head_dim_half) |
| 33 | + freqs_real = freqs_real.reshape(-1, head_dim_half) |
| 34 | + freqs_imag = freqs_imag.reshape(-1, head_dim_half) |
| 35 | + |
| 36 | + # Broadcast/slice to (seq_len, head_dim_half) |
| 37 | + if freqs_real.shape[0] < seq_len: |
| 38 | + if freqs_real.shape[0] == 1: |
| 39 | + freqs_real = freqs_real.expand(seq_len, -1) |
| 40 | + freqs_imag = freqs_imag.expand(seq_len, -1) |
| 41 | + else: |
| 42 | + raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}") |
| 43 | + elif freqs_real.shape[0] > seq_len: |
| 44 | + freqs_real = freqs_real[:seq_len] |
| 45 | + freqs_imag = freqs_imag[:seq_len] |
| 46 | + |
| 47 | + return freqs_real, freqs_imag |
| 48 | + |
| 49 | + |
| 50 | +def _cast_and_contiguous(q, k, freqs_real, freqs_imag): |
| 51 | + # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf |
| 52 | + compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype |
| 53 | + |
| 54 | + if k.dtype != q.dtype: |
| 55 | + k = k.to(q.dtype) |
| 56 | + |
| 57 | + q = q.to(compute_dtype).contiguous() |
| 58 | + k = k.to(compute_dtype).contiguous() |
| 59 | + freqs_real = freqs_real.to(compute_dtype).contiguous() |
| 60 | + freqs_imag = freqs_imag.to(compute_dtype).contiguous() |
| 61 | + return q, k, freqs_real, freqs_imag, compute_dtype |
| 62 | + |
| 63 | + |
| 64 | +@triton.jit |
| 65 | +def _triton_llama4_rope_npu( |
| 66 | + q_ptr, |
| 67 | + k_ptr, |
| 68 | + freqs_real_ptr, |
| 69 | + freqs_imag_ptr, |
| 70 | + q_row_stride, |
| 71 | + k_row_stride, |
| 72 | + q_head_stride, |
| 73 | + k_head_stride, |
| 74 | + freqs_row_stride, |
| 75 | + sl, |
| 76 | + bs: tl.constexpr, |
| 77 | + n_qh: tl.constexpr, |
| 78 | + n_kh: tl.constexpr, |
| 79 | + hd: tl.constexpr, |
| 80 | + BLOCK_Q: tl.constexpr, |
| 81 | + BLOCK_K: tl.constexpr, |
| 82 | + imag_sign: tl.constexpr, |
| 83 | +): |
| 84 | + """ |
| 85 | + Llama4 RoPE on Ascend NPU for interleaved complex layout: |
| 86 | + - q/k shape: (bs, sl, n_heads, hd) |
| 87 | + - last dim layout: [real0, imag0, real1, imag1, ...] |
| 88 | + - freqs_real/imag: (sl, hd//2) |
| 89 | + """ |
| 90 | + pid = tl.program_id(0).to(tl.int64) |
| 91 | + batch_idx = pid // sl |
| 92 | + seq_idx = pid % sl |
| 93 | + |
| 94 | + if batch_idx >= bs: |
| 95 | + return |
| 96 | + |
| 97 | + q_base = q_ptr + pid * q_row_stride |
| 98 | + k_base = k_ptr + pid * k_row_stride |
| 99 | + |
| 100 | + freq_base = seq_idx * freqs_row_stride |
| 101 | + hd_idx = tl.arange(0, hd) |
| 102 | + hd_mask = hd_idx < (hd) |
| 103 | + |
| 104 | + freq_idx = tl.arange(0, hd // 2) |
| 105 | + freq_mask = freq_idx < (hd // 2) |
| 106 | + |
| 107 | + freqs_real = tl.load(freqs_real_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) |
| 108 | + freqs_imag = tl.load(freqs_imag_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) * imag_sign |
| 109 | + |
| 110 | + # Q heads (chunked for UB) |
| 111 | + for qh_block in range(0, n_qh, BLOCK_Q): |
| 112 | + qh_idx = tl.arange(0, BLOCK_Q) + qh_block |
| 113 | + qh_mask = qh_idx < n_qh |
| 114 | + block_mask = qh_mask[:, None] & hd_mask[None, :] |
| 115 | + |
| 116 | + head_ptr = q_base + qh_idx[:, None] * q_head_stride |
| 117 | + |
| 118 | + q_pair = tl.load( |
| 119 | + head_ptr + hd_idx[None, :], |
| 120 | + mask=block_mask, |
| 121 | + other=0.0, |
| 122 | + ) |
| 123 | + q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True) |
| 124 | + q_real, q_imag = tl.split(q_pair) |
| 125 | + |
| 126 | + new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag)) |
| 127 | + new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real) |
| 128 | + new_q_pair = tl.interleave(new_real, new_imag) |
| 129 | + |
| 130 | + tl.store(head_ptr + hd_idx[None, :], new_q_pair, mask=block_mask) |
| 131 | + |
| 132 | + # K heads (chunked for UB) |
| 133 | + for kh_block in range(0, n_kh, BLOCK_K): |
| 134 | + kh_idx = tl.arange(0, BLOCK_K) + kh_block |
| 135 | + kh_mask = kh_idx < n_kh |
| 136 | + block_mask = kh_mask[:, None] & hd_mask[None, :] |
| 137 | + |
| 138 | + head_ptr = k_base + kh_idx[:, None] * k_head_stride |
| 139 | + |
| 140 | + k_pair = tl.load( |
| 141 | + head_ptr + hd_idx[None, :], |
| 142 | + mask=block_mask, |
| 143 | + other=0.0, |
| 144 | + ) |
| 145 | + |
| 146 | + k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True) |
| 147 | + k_real, k_imag = tl.split(k_pair) |
| 148 | + |
| 149 | + new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag)) |
| 150 | + new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real) |
| 151 | + new_k_pair = tl.interleave(new_real, new_imag) |
| 152 | + |
| 153 | + tl.store(head_ptr + hd_idx[None, :], new_k_pair, mask=block_mask) |
| 154 | + |
| 155 | + |
| 156 | +def llama4_rope_forward(q, k, freqs_cis): |
| 157 | + """ |
| 158 | + Ascend NPU implementation of Llama4 RoPE. |
| 159 | +
|
| 160 | + q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout. |
| 161 | + freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)). |
| 162 | + """ |
| 163 | + original_dtype = q.dtype |
| 164 | + |
| 165 | + bs, sl, n_qh, hd = q.shape |
| 166 | + _, _, n_kh, _ = k.shape |
| 167 | + if hd % 2 != 0: |
| 168 | + raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}") |
| 169 | + hd_half = hd // 2 |
| 170 | + |
| 171 | + freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half) |
| 172 | + q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag) |
| 173 | + |
| 174 | + # UB tiling strategy: tile heads dimension only |
| 175 | + dtype_size = q.element_size() |
| 176 | + shapes = ((n_qh, hd), (n_kh, hd)) |
| 177 | + tile_shapes = compute_default_tiling_strategy( |
| 178 | + safety_margin=0.90, |
| 179 | + dtype_size=dtype_size, |
| 180 | + memory_multiplier=12.0, |
| 181 | + shapes=shapes, |
| 182 | + tiling_dims=(0, 0), |
| 183 | + ) |
| 184 | + |
| 185 | + if tile_shapes is not None and len(tile_shapes) == len(shapes): |
| 186 | + q_tile_shape, k_tile_shape = tile_shapes |
| 187 | + BLOCK_Q, _ = q_tile_shape |
| 188 | + BLOCK_K, _ = k_tile_shape |
| 189 | + else: |
| 190 | + BLOCK_Q = triton.next_power_of_2(n_qh) |
| 191 | + BLOCK_K = triton.next_power_of_2(n_kh) |
| 192 | + |
| 193 | + n_row = bs * sl |
| 194 | + |
| 195 | + _triton_llama4_rope_npu[(n_row,)]( |
| 196 | + q, |
| 197 | + k, |
| 198 | + freqs_real, |
| 199 | + freqs_imag, |
| 200 | + q.stride(1), |
| 201 | + k.stride(1), |
| 202 | + q.stride(2), |
| 203 | + k.stride(2), |
| 204 | + freqs_real.stride(0), |
| 205 | + sl, |
| 206 | + bs, |
| 207 | + n_qh, |
| 208 | + n_kh, |
| 209 | + hd, |
| 210 | + BLOCK_Q, |
| 211 | + BLOCK_K, |
| 212 | + imag_sign=1.0, |
| 213 | + ) |
| 214 | + |
| 215 | + if compute_dtype != original_dtype: |
| 216 | + q = q.to(original_dtype) |
| 217 | + k = k.to(original_dtype) |
| 218 | + return q, k |
| 219 | + |
| 220 | + |
| 221 | +def llama4_rope_backward(dq, dk, freqs_cis): |
| 222 | + """ |
| 223 | + Ascend NPU implementation of Llama4 RoPE. |
| 224 | +
|
| 225 | + q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout. |
| 226 | + freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)). |
| 227 | + """ |
| 228 | + original_dtype = dq.dtype |
| 229 | + |
| 230 | + bs, sl, n_qh, hd = dq.shape |
| 231 | + _, _, n_kh, _ = dk.shape |
| 232 | + if hd % 2 != 0: |
| 233 | + raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}") |
| 234 | + hd_half = hd // 2 |
| 235 | + |
| 236 | + freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half) |
| 237 | + dq, dk, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(dq, dk, freqs_real, freqs_imag) |
| 238 | + |
| 239 | + # UB tiling strategy: tile heads dimension only |
| 240 | + dtype_size = dq.element_size() |
| 241 | + shapes = ((n_qh, hd), (n_kh, hd)) |
| 242 | + tile_shapes = compute_default_tiling_strategy( |
| 243 | + safety_margin=0.90, |
| 244 | + dtype_size=dtype_size, |
| 245 | + memory_multiplier=12.0, |
| 246 | + shapes=shapes, |
| 247 | + tiling_dims=(0, 0), |
| 248 | + ) |
| 249 | + |
| 250 | + if tile_shapes is not None and len(tile_shapes) == len(shapes): |
| 251 | + q_tile_shape, k_tile_shape = tile_shapes |
| 252 | + BLOCK_Q, _ = q_tile_shape |
| 253 | + BLOCK_K, _ = k_tile_shape |
| 254 | + else: |
| 255 | + BLOCK_Q = triton.next_power_of_2(n_qh) |
| 256 | + BLOCK_K = triton.next_power_of_2(n_kh) |
| 257 | + |
| 258 | + n_row = bs * sl |
| 259 | + |
| 260 | + _triton_llama4_rope_npu[(n_row,)]( |
| 261 | + dq, |
| 262 | + dk, |
| 263 | + freqs_real, |
| 264 | + freqs_imag, |
| 265 | + dq.stride(1), |
| 266 | + dk.stride(1), |
| 267 | + dq.stride(2), |
| 268 | + dk.stride(2), |
| 269 | + freqs_real.stride(0), |
| 270 | + sl, |
| 271 | + bs, |
| 272 | + n_qh, |
| 273 | + n_kh, |
| 274 | + hd, |
| 275 | + BLOCK_Q, |
| 276 | + BLOCK_K, |
| 277 | + imag_sign=-1.0, |
| 278 | + ) |
| 279 | + |
| 280 | + if compute_dtype != original_dtype: |
| 281 | + dq = dq.to(original_dtype) |
| 282 | + dk = dk.to(original_dtype) |
| 283 | + return dq, dk |
| 284 | + |
| 285 | + |
| 286 | +class LigerLlama4RopeFunction(torch.autograd.Function): |
| 287 | + @staticmethod |
| 288 | + def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None): |
| 289 | + # BLOCK_SIZE is ignored for Ascend (we auto-tile heads by UB), kept for API compatibility |
| 290 | + q_out, k_out = llama4_rope_forward(q, k, freqs_cis) |
| 291 | + ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis) |
| 292 | + return q_out, k_out |
| 293 | + |
| 294 | + @staticmethod |
| 295 | + def backward(ctx, dq, dk): |
| 296 | + (freqs_cis,) = ctx.saved_tensors |
| 297 | + dq_out, dk_out = llama4_rope_backward(dq, dk, freqs_cis) |
| 298 | + return dq_out, dk_out, None, None |
0 commit comments