|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy |
| 6 | + |
| 7 | + |
| 8 | +@triton.jit |
| 9 | +def _triton_qwen2vl_mrope_npu( |
| 10 | + q_ptr, |
| 11 | + q_row_stride, |
| 12 | + k_ptr, |
| 13 | + k_row_stride, |
| 14 | + cos, |
| 15 | + sin, |
| 16 | + sl, |
| 17 | + bs: tl.constexpr, |
| 18 | + n_qh: tl.constexpr, |
| 19 | + n_kh: tl.constexpr, |
| 20 | + hd: tl.constexpr, |
| 21 | + mrope_section_t: tl.constexpr, |
| 22 | + mrope_section_h: tl.constexpr, |
| 23 | + BLOCK_Q: tl.constexpr, |
| 24 | + BLOCK_K: tl.constexpr, |
| 25 | + BACKWARD_PASS: tl.constexpr = False, |
| 26 | +): |
| 27 | + pid = tl.program_id(0).to(tl.int64) |
| 28 | + |
| 29 | + t_end = mrope_section_t |
| 30 | + h_end = t_end + mrope_section_h |
| 31 | + |
| 32 | + t_cos = cos + pid * hd |
| 33 | + h_cos = t_cos + bs * sl * hd |
| 34 | + w_cos = h_cos + bs * sl * hd |
| 35 | + t_sin = sin + pid * hd |
| 36 | + h_sin = t_sin + bs * sl * hd |
| 37 | + w_sin = h_sin + bs * sl * hd |
| 38 | + |
| 39 | + q_base = q_ptr + pid * q_row_stride |
| 40 | + k_base = k_ptr + pid * k_row_stride |
| 41 | + |
| 42 | + d_idx = tl.arange(0, hd // 2) |
| 43 | + d_mask = d_idx < (hd // 2) |
| 44 | + |
| 45 | + pos_mask_t = d_idx < t_end |
| 46 | + pos_mask_h = (d_idx >= t_end) & (d_idx < h_end) |
| 47 | + |
| 48 | + text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0) |
| 49 | + text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0) |
| 50 | + height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0) |
| 51 | + height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0) |
| 52 | + width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0) |
| 53 | + width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0) |
| 54 | + |
| 55 | + cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals)) |
| 56 | + sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals)) |
| 57 | + |
| 58 | + for qh_block in range(0, n_qh, BLOCK_Q): |
| 59 | + qh_idx = tl.arange(0, BLOCK_Q) + qh_block |
| 60 | + qh_mask = qh_idx < n_qh |
| 61 | + |
| 62 | + block_mask = qh_mask[:, None] & d_mask[None, :] |
| 63 | + offsets = qh_idx[:, None] * hd + d_idx[None, :] |
| 64 | + |
| 65 | + q_left = tl.load(q_base + offsets, mask=block_mask, other=0) |
| 66 | + q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0) |
| 67 | + |
| 68 | + if not BACKWARD_PASS: |
| 69 | + new_left = q_left * cos_vals - q_right * sin_vals |
| 70 | + new_right = q_right * cos_vals + q_left * sin_vals |
| 71 | + else: |
| 72 | + new_left = q_left * cos_vals + q_right * sin_vals |
| 73 | + new_right = q_right * cos_vals - q_left * sin_vals |
| 74 | + |
| 75 | + tl.store(q_base + offsets, new_left, mask=block_mask) |
| 76 | + tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask) |
| 77 | + |
| 78 | + for kh_block in range(0, n_kh, BLOCK_K): |
| 79 | + kh_idx = tl.arange(0, BLOCK_K) + kh_block |
| 80 | + kh_mask = kh_idx < n_kh |
| 81 | + |
| 82 | + block_mask = kh_mask[:, None] & d_mask[None, :] |
| 83 | + offsets = kh_idx[:, None] * hd + d_idx[None, :] |
| 84 | + |
| 85 | + k_left = tl.load(k_base + offsets, mask=block_mask, other=0) |
| 86 | + k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0) |
| 87 | + |
| 88 | + if not BACKWARD_PASS: |
| 89 | + new_left = k_left * cos_vals - k_right * sin_vals |
| 90 | + new_right = k_right * cos_vals + k_left * sin_vals |
| 91 | + else: |
| 92 | + new_left = k_left * cos_vals + k_right * sin_vals |
| 93 | + new_right = k_right * cos_vals - k_left * sin_vals |
| 94 | + |
| 95 | + tl.store(k_base + offsets, new_left, mask=block_mask) |
| 96 | + tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask) |
| 97 | + |
| 98 | + |
| 99 | +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): |
| 100 | + # transpose it back to the physical shape because Triton looks at the physical storage |
| 101 | + # note: q and k are incontiguous before the transformation and will become contiguous after transpose |
| 102 | + q = q.transpose(1, 2) |
| 103 | + k = k.transpose(1, 2) |
| 104 | + |
| 105 | + batch_size, seq_len, n_q_head, head_dim = q.shape |
| 106 | + n_kv_head = k.shape[2] |
| 107 | + pad_hd = triton.next_power_of_2(head_dim) |
| 108 | + pad_n_q_head = triton.next_power_of_2(n_q_head) |
| 109 | + pad_n_kv_head = triton.next_power_of_2(n_kv_head) |
| 110 | + |
| 111 | + n_row = batch_size * seq_len |
| 112 | + |
| 113 | + # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous |
| 114 | + q = q.contiguous() |
| 115 | + k = k.contiguous() |
| 116 | + cos = cos.contiguous() |
| 117 | + sin = sin.contiguous() |
| 118 | + |
| 119 | + # Compute tiling strategy based on UB capacity |
| 120 | + dtype_size = q.element_size() |
| 121 | + # MROPE forward tiling strategy: |
| 122 | + # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each |
| 123 | + # - In q heads loop (peak memory): |
| 124 | + # * q_left: BLOCK_Q * (pad_hd // 2) elements |
| 125 | + # * q_right: BLOCK_Q * (pad_hd // 2) elements |
| 126 | + # * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result) |
| 127 | + # * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result) |
| 128 | + # * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements |
| 129 | + # - In k heads loop (peak memory): |
| 130 | + # * k_left: BLOCK_K * (pad_hd // 2) elements |
| 131 | + # * k_right: BLOCK_K * (pad_hd // 2) elements |
| 132 | + # * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result) |
| 133 | + # * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result) |
| 134 | + # * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements |
| 135 | + # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case |
| 136 | + # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements |
| 137 | + # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits |
| 138 | + # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits |
| 139 | + # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits |
| 140 | + # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) |
| 141 | + # - tiling_dims: (0, 0) means first dimension of each shape can be tiled |
| 142 | + # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd)) |
| 143 | + shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) |
| 144 | + tile_shapes = compute_default_tiling_strategy( |
| 145 | + safety_margin=0.90, |
| 146 | + dtype_size=dtype_size, |
| 147 | + memory_multiplier=3.0, |
| 148 | + shapes=shapes, |
| 149 | + tiling_dims=(0, 0), |
| 150 | + ) |
| 151 | + |
| 152 | + if tile_shapes is not None and len(tile_shapes) == len(shapes): |
| 153 | + # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd)) |
| 154 | + q_tile_shape, k_tile_shape = tile_shapes |
| 155 | + BLOCK_Q, _ = q_tile_shape |
| 156 | + BLOCK_K, _ = k_tile_shape |
| 157 | + else: |
| 158 | + # Fallback to conservative defaults |
| 159 | + BLOCK_Q = triton.next_power_of_2(pad_n_q_head) |
| 160 | + BLOCK_K = triton.next_power_of_2(pad_n_kv_head) |
| 161 | + _triton_qwen2vl_mrope_npu[(n_row,)]( |
| 162 | + q, |
| 163 | + q.stride(1), |
| 164 | + k, |
| 165 | + k.stride(1), |
| 166 | + cos, |
| 167 | + sin, |
| 168 | + seq_len, |
| 169 | + batch_size, |
| 170 | + n_q_head, |
| 171 | + n_kv_head, |
| 172 | + head_dim, |
| 173 | + mrope_section[0], |
| 174 | + mrope_section[1], |
| 175 | + BLOCK_Q, |
| 176 | + BLOCK_K, |
| 177 | + BACKWARD_PASS=False, |
| 178 | + ) |
| 179 | + return q.transpose(1, 2), k.transpose(1, 2), cos, sin |
| 180 | + |
| 181 | + |
| 182 | +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): |
| 183 | + dq = dq.transpose(1, 2) |
| 184 | + dk = dk.transpose(1, 2) |
| 185 | + |
| 186 | + batch_size, seq_len, n_q_head, head_dim = dq.shape |
| 187 | + n_kv_head = dk.shape[2] |
| 188 | + pad_hd = triton.next_power_of_2(head_dim) |
| 189 | + pad_n_q_head = triton.next_power_of_2(n_q_head) |
| 190 | + pad_n_kv_head = triton.next_power_of_2(n_kv_head) |
| 191 | + |
| 192 | + n_row = batch_size * seq_len |
| 193 | + |
| 194 | + # ensure dq and dk are contiguous |
| 195 | + dq = dq.contiguous() |
| 196 | + dk = dk.contiguous() |
| 197 | + |
| 198 | + # Compute tiling strategy based on UB capacity |
| 199 | + dtype_size = dq.element_size() |
| 200 | + # MROPE backward tiling strategy: |
| 201 | + # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each |
| 202 | + # - In q heads loop (peak memory): |
| 203 | + # * q_left: BLOCK_Q * (pad_hd // 2) elements |
| 204 | + # * q_right: BLOCK_Q * (pad_hd // 2) elements |
| 205 | + # * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result) |
| 206 | + # * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result) |
| 207 | + # * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements |
| 208 | + # - In k heads loop (peak memory): |
| 209 | + # * k_left: BLOCK_K * (pad_hd // 2) elements |
| 210 | + # * k_right: BLOCK_K * (pad_hd // 2) elements |
| 211 | + # * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result) |
| 212 | + # * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result) |
| 213 | + # * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements |
| 214 | + # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case |
| 215 | + # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements |
| 216 | + # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits |
| 217 | + # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits |
| 218 | + # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits |
| 219 | + # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) |
| 220 | + # - tiling_dims: (0, 0) means first dimension of each shape can be tiled |
| 221 | + # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd)) |
| 222 | + shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) |
| 223 | + tile_shapes = compute_default_tiling_strategy( |
| 224 | + safety_margin=0.90, |
| 225 | + dtype_size=dtype_size, |
| 226 | + memory_multiplier=3.0, |
| 227 | + shapes=shapes, |
| 228 | + tiling_dims=(0, 0), |
| 229 | + ) |
| 230 | + |
| 231 | + if tile_shapes is not None and len(tile_shapes) == len(shapes): |
| 232 | + # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd)) |
| 233 | + q_tile_shape, k_tile_shape = tile_shapes |
| 234 | + BLOCK_Q, _ = q_tile_shape |
| 235 | + BLOCK_K, _ = k_tile_shape |
| 236 | + else: |
| 237 | + # Fallback to conservative defaults |
| 238 | + BLOCK_Q = triton.next_power_of_2(pad_n_q_head) |
| 239 | + BLOCK_K = triton.next_power_of_2(pad_n_kv_head) |
| 240 | + _triton_qwen2vl_mrope_npu[(n_row,)]( |
| 241 | + dq, |
| 242 | + dq.stride(1), |
| 243 | + dk, |
| 244 | + dk.stride(1), |
| 245 | + cos, |
| 246 | + sin, |
| 247 | + seq_len, |
| 248 | + batch_size, |
| 249 | + n_q_head, |
| 250 | + n_kv_head, |
| 251 | + head_dim, |
| 252 | + mrope_section[0], |
| 253 | + mrope_section[1], |
| 254 | + BLOCK_Q, |
| 255 | + BLOCK_K, |
| 256 | + BACKWARD_PASS=True, |
| 257 | + ) |
| 258 | + return dq.transpose(1, 2), dk.transpose(1, 2) |
| 259 | + |
| 260 | + |
| 261 | +class LigerQwen2VLMRopeFunction(torch.autograd.Function): |
| 262 | + @staticmethod |
| 263 | + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): |
| 264 | + """ |
| 265 | + q size: (bsz, n_q_head, seq_len, head_dim) |
| 266 | + k size: (bsz, n_kv_head, seq_len, head_dim) |
| 267 | + cos size: (3, bsz, seq_len, head_dim) |
| 268 | + sin size: (3, bsz, seq_len, head_dim) |
| 269 | + """ |
| 270 | + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) |
| 271 | + ctx.save_for_backward(cos, sin) |
| 272 | + ctx.mrope_section = mrope_section |
| 273 | + return q, k |
| 274 | + |
| 275 | + def backward(ctx, dq, dk): |
| 276 | + """ |
| 277 | + dq size: (bsz, n_q_head, seq_len, head_dim) |
| 278 | + dk size: (bsz, n_kv_head, seq_len, head_dim) |
| 279 | + cos size: (3, bsz, seq_len, head_dim) |
| 280 | + sin size: (3, bsz, seq_len, head_dim) |
| 281 | + """ |
| 282 | + cos, sin = ctx.saved_tensors |
| 283 | + mrope_section = ctx.mrope_section |
| 284 | + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) |
| 285 | + return dq, dk, None, None, None, None |
0 commit comments