
Commit fc88829

yzh119 and cyx-6 authored
feat: Fused rope fp8 quantize kernel for MLA (#1339)
## 📌 Description

Adds a fused RoPE + FP8 quantization kernel that prepares inputs for the FP8 MLA kernel.

Reference: https://github.com/NVIDIA/TensorRT-LLM/blob/0df758ec9f8409410bac8b60d117374054391c2d/cpp/tensorrt_llm/kernels/mlaKernels.cu#L358

Co-authored-by: Yaxing Cai <[email protected]>
1 parent b7bfd00 commit fc88829

File tree

6 files changed: +505 -5 lines changed


csrc/flashinfer_rope_ops.cu

Lines changed: 8 additions & 0 deletions
```diff
@@ -39,6 +39,12 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r
                                       at::Tensor k_rope, at::Tensor cos_sin_cache,
                                       at::Tensor pos_ids, bool interleave);
 
+void mla_rope_quantize(at::Tensor q_rope_in, at::Tensor k_rope_in, at::Tensor q_nope_in,
+                       at::Tensor k_nope_in, at::Tensor q_rope_out, at::Tensor k_rope_out,
+                       at::Tensor q_nope_out, at::Tensor k_nope_out, at::Tensor cos_sin_cache,
+                       at::Tensor pos_ids, double quant_scale_q, double quant_scale_kv,
+                       bool interleave);
+
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // "Apply RoPE"
   m.def("apply_rope", apply_rope);
@@ -50,4 +56,6 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("apply_llama31_rope_pos_ids", apply_llama31_rope_pos_ids);
   // "Apply RoPE with positional ids and cosine/sine cache"
   m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
+  // "MLA RoPE Quantize"
+  m.def("mla_rope_quantize", mla_rope_quantize);
 }
```

csrc/rope.cu

Lines changed: 95 additions & 0 deletions
```diff
@@ -259,3 +259,98 @@ void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, a
     });
   });
 }
+
+void mla_rope_quantize(at::Tensor q_rope_in, at::Tensor k_rope_in, at::Tensor q_nope_in,
+                       at::Tensor k_nope_in, at::Tensor q_rope_out, at::Tensor k_rope_out,
+                       at::Tensor q_nope_out, at::Tensor k_nope_out, at::Tensor cos_sin_cache,
+                       at::Tensor pos_ids, double quant_scale_q, double quant_scale_kv,
+                       bool interleave) {
+  CHECK_LAST_DIM_CONTIGUOUS(q_rope_in);
+  CHECK_LAST_DIM_CONTIGUOUS(k_rope_in);
+  CHECK_LAST_DIM_CONTIGUOUS(q_nope_in);
+  CHECK_LAST_DIM_CONTIGUOUS(k_nope_in);
+  CHECK_LAST_DIM_CONTIGUOUS(q_rope_out);
+  CHECK_LAST_DIM_CONTIGUOUS(k_rope_out);
+  CHECK_LAST_DIM_CONTIGUOUS(q_nope_out);
+  CHECK_LAST_DIM_CONTIGUOUS(k_nope_out);
+  CHECK_INPUT(cos_sin_cache);
+  CHECK_INPUT(pos_ids);
+
+  CHECK_EQ(q_rope_in.size(-1), 64);
+  CHECK_EQ(k_rope_in.size(-1), 64);
+  CHECK_EQ(q_nope_in.size(-1), 512);
+  CHECK_EQ(k_nope_in.size(-1), 512);
+  CHECK_EQ(q_rope_out.size(-1), 64);
+  CHECK_EQ(k_rope_out.size(-1), 64);
+  CHECK_EQ(q_nope_out.size(-1), 512);
+  CHECK_EQ(k_nope_out.size(-1), 512);
+  auto scalar_type_in = q_rope_in.scalar_type();
+  TORCH_CHECK(scalar_type_in == k_rope_in.scalar_type());
+  TORCH_CHECK(scalar_type_in == q_nope_in.scalar_type());
+  TORCH_CHECK(scalar_type_in == k_nope_in.scalar_type());
+  auto quant_type_out = q_rope_out.scalar_type();
+  TORCH_CHECK(quant_type_out == k_rope_out.scalar_type());
+  TORCH_CHECK(quant_type_out == q_nope_out.scalar_type());
+  TORCH_CHECK(quant_type_out == k_nope_out.scalar_type());
+
+  CHECK_DIM(3, q_rope_in);   // q_rope_in: (nnz, H_Q, 64)
+  CHECK_DIM(3, q_nope_in);   // q_nope_in: (nnz, H_Q, 512)
+  CHECK_DIM(2, k_rope_in);   // k_rope_in: (nnz, 64)
+  CHECK_DIM(2, k_nope_in);   // k_nope_in: (nnz, 512)
+  CHECK_DIM(3, q_rope_out);  // q_rope_out: (nnz, H_Q, 64)
+  CHECK_DIM(3, q_nope_out);  // q_nope_out: (nnz, H_Q, 512)
+  CHECK_DIM(2, k_rope_out);  // k_rope_out: (nnz, 64)
+  CHECK_DIM(2, k_nope_out);  // k_nope_out: (nnz, 512)
+  uint32_t nnz = q_rope_in.size(0);
+  CHECK_EQ(q_nope_in.size(0), nnz);
+  CHECK_EQ(k_nope_in.size(0), nnz);
+  CHECK_EQ(q_rope_out.size(0), nnz);
+  CHECK_EQ(k_rope_out.size(0), nnz);
+  CHECK_EQ(q_nope_out.size(0), nnz);
+  CHECK_EQ(k_nope_out.size(0), nnz);
+  uint32_t num_heads = q_rope_in.size(1);
+  CHECK_EQ(q_rope_in.size(1), num_heads);
+  CHECK_EQ(q_nope_in.size(1), num_heads);
+  CHECK_EQ(q_rope_out.size(1), num_heads);
+  CHECK_EQ(q_nope_out.size(1), num_heads);
+
+  const uint32_t q_rope_in_stride_n = q_rope_in.stride(0);
+  const uint32_t q_rope_in_stride_h = q_rope_in.stride(1);
+  const uint32_t q_nope_in_stride_n = q_nope_in.stride(0);
+  const uint32_t q_nope_in_stride_h = q_nope_in.stride(1);
+  const uint32_t q_rope_out_stride_n = q_rope_out.stride(0);
+  const uint32_t q_rope_out_stride_h = q_rope_out.stride(1);
+  const uint32_t q_nope_out_stride_n = q_nope_out.stride(0);
+  const uint32_t q_nope_out_stride_h = q_nope_out.stride(1);
+  const uint32_t k_rope_in_stride = k_rope_in.stride(0);
+  const uint32_t k_nope_in_stride = k_nope_in.stride(0);
+  const uint32_t k_rope_out_stride = k_rope_out.stride(0);
+  const uint32_t k_nope_out_stride = k_nope_out.stride(0);
+
+  const c10::cuda::OptionalCUDAGuard device_guard(q_rope_in.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(scalar_type_in, c_type, [&] {
+    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(quant_type_out, c_quant_type, [&] {
+      return DISPATCH_PYTORCH_IDTYPE_TO_CTYPE(pos_ids.scalar_type(), c_idtype, [&] {
+        cudaError_t status = MLARopeQuantize(
+            static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
+            static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
+            static_cast<c_quant_type*>(q_rope_out.data_ptr()),
+            static_cast<c_quant_type*>(k_rope_out.data_ptr()),
+            static_cast<c_quant_type*>(q_nope_out.data_ptr()),
+            static_cast<c_quant_type*>(k_nope_out.data_ptr()),
+            static_cast<float*>(cos_sin_cache.data_ptr()),
+            static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_heads, q_rope_in_stride_n,
+            q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, q_rope_out_stride_n,
+            q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, k_rope_in_stride,
+            k_nope_in_stride, k_rope_out_stride, k_nope_out_stride, quant_scale_q, quant_scale_kv,
+            interleave, stream);
+        TORCH_CHECK(status == cudaSuccess,
+                    "MLARopeQuantize failed with error code " +
+                        std::string(cudaGetErrorString(status)));
+        return true;
+      });
+    });
+  });
+}
```
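
Functionally, the host wrapper above only validates shapes and dtypes and then dispatches to the `MLARopeQuantize` CUDA kernel over the fp16/bf16 input dtype, the fp8 output dtype, and the index dtype. For orientation, the fused op should be roughly equivalent to the unfused reference below, which applies non-interleaved (NeoX-style) RoPE to the 64-wide rope parts and then scales and casts all four tensors to fp8. This is a hedged sketch, not the kernel's implementation: the cos/sin cache layout (first half cos, second half sin) and the multiply-by-scale-then-cast quantization convention are assumptions.

```python
import torch


def mla_rope_quantize_reference(
    q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids,
    quant_scale_q=1.0, quant_scale_kv=1.0, out_dtype=torch.float8_e4m3fn,
):
    # Assumed cache layout: (max_pos, 64) fp32, first 32 columns cos, last 32 columns sin.
    cos, sin = cos_sin_cache[pos_ids].float().chunk(2, dim=-1)  # each (nnz, 32)
    cos = torch.cat([cos, cos], dim=-1)  # (nnz, 64)
    sin = torch.cat([sin, sin], dim=-1)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def rope(x):
        # Broadcast over the head dim for q (nnz, H_Q, 64); k is (nnz, 64).
        c = cos if x.dim() == 2 else cos[:, None, :]
        s = sin if x.dim() == 2 else sin[:, None, :]
        return x.float() * c + rotate_half(x.float()) * s

    def quantize(x, scale):
        # Assumed convention: multiply by the quant scale, then cast to fp8.
        return (x.float() * scale).to(out_dtype)

    return (
        quantize(rope(q_rope), quant_scale_q),
        quantize(rope(k_rope), quant_scale_kv),
        quantize(q_nope, quant_scale_q),
        quantize(k_nope, quant_scale_kv),
    )
```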

flashinfer/rope.py

Lines changed: 124 additions & 0 deletions
```diff
@@ -175,6 +175,61 @@ def _fake_apply_rope_pos_ids(
     pass
 
 
+@register_custom_op(
+    "flashinfer::mla_rope_quantize",
+    mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"),
+)
+def _mla_rope_quantize(
+    q_rope_in: torch.Tensor,
+    k_rope_in: torch.Tensor,
+    q_nope_in: torch.Tensor,
+    k_nope_in: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    q_rope_out: torch.Tensor,
+    k_rope_out: torch.Tensor,
+    q_nope_out: torch.Tensor,
+    k_nope_out: torch.Tensor,
+    quant_scale_q: float,
+    quant_scale_kv: float,
+    interleave: bool,
+) -> None:
+    get_rope_module().mla_rope_quantize(
+        q_rope_in,
+        k_rope_in,
+        q_nope_in,
+        k_nope_in,
+        q_rope_out,
+        k_rope_out,
+        q_nope_out,
+        k_nope_out,
+        cos_sin_cache,
+        pos_ids,
+        quant_scale_q,
+        quant_scale_kv,
+        interleave,
+    )
+
+
+@register_fake_op("flashinfer::mla_rope_quantize")
+def _fake_mla_rope_quantize(
+    q_rope_in: torch.Tensor,
+    k_rope_in: torch.Tensor,
+    q_nope_in: torch.Tensor,
+    k_nope_in: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    q_rope_out: torch.Tensor,
+    k_rope_out: torch.Tensor,
+    q_nope_out: torch.Tensor,
+    k_nope_out: torch.Tensor,
+    quant_scale_q: float,
+    quant_scale_kv: float,
+    interleave: bool,
+) -> None:
+    pass
+
+
 @register_custom_op(
     "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope")
 )
@@ -1094,3 +1149,72 @@ def apply_rope_with_cos_sin_cache_inplace(
         pos_ids=positions,
         interleave=(not is_neox),
     )
+
+
+def mla_rope_quantize_fp8(
+    q_rope: torch.Tensor,
+    k_rope: torch.Tensor,
+    q_nope: torch.Tensor,
+    k_nope: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    is_neox: bool = True,
+    quantize_dtype: Optional[torch.dtype] = None,
+    quant_scale_q: float = 1.0,
+    quant_scale_kv: float = 1.0,
+    q_rope_out: Optional[torch.Tensor] = None,
+    k_rope_out: Optional[torch.Tensor] = None,
+    q_nope_out: Optional[torch.Tensor] = None,
+    k_nope_out: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if cos_sin_cache.dtype != torch.float32:
+        raise ValueError("cos_sin_cache should be float32")
+
+    # Infer quantize_dtype from output tensors or default to float8_e4m3fn
+    if quantize_dtype is None:
+        for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out):
+            if out is not None:
+                quantize_dtype = out.dtype
+                break
+        else:
+            quantize_dtype = torch.float8_e4m3fn
+
+    # Allocate output tensors if not provided
+    q_rope_out = (
+        q_rope_out
+        if q_rope_out is not None
+        else torch.empty_like(q_rope, dtype=quantize_dtype)
+    )
+    k_rope_out = (
+        k_rope_out
+        if k_rope_out is not None
+        else torch.empty_like(k_rope, dtype=quantize_dtype)
+    )
+    q_nope_out = (
+        q_nope_out
+        if q_nope_out is not None
+        else torch.empty_like(q_nope, dtype=quantize_dtype)
+    )
+    k_nope_out = (
+        k_nope_out
+        if k_nope_out is not None
+        else torch.empty_like(k_nope, dtype=quantize_dtype)
+    )
+
+    _mla_rope_quantize(
+        q_rope,
+        k_rope,
+        q_nope,
+        k_nope,
+        cos_sin_cache,
+        pos_ids,
+        q_rope_out,
+        k_rope_out,
+        q_nope_out,
+        k_nope_out,
+        quant_scale_q,
+        quant_scale_kv,
+        not is_neox,  # interleave
+    )
+
+    return q_rope_out, k_rope_out, q_nope_out, k_nope_out
```
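
Putting it together, the shape checks in `csrc/rope.cu` imply the following call pattern for the new public wrapper: q tensors are 3-D with a head dimension, k tensors are 2-D, rope parts are 64-wide, nope parts are 512-wide, `cos_sin_cache` is fp32, and outputs default to `float8_e4m3fn`. A minimal usage sketch; the concrete sizes, the fp16 input dtype, the int32 `pos_ids` dtype, and the placeholder cache values are illustrative assumptions, not part of this change.

```python
import torch
from flashinfer.rope import mla_rope_quantize_fp8

nnz, num_heads, max_pos = 16, 128, 4096  # illustrative sizes
device = "cuda"

q_rope = torch.randn(nnz, num_heads, 64, dtype=torch.float16, device=device)
q_nope = torch.randn(nnz, num_heads, 512, dtype=torch.float16, device=device)
k_rope = torch.randn(nnz, 64, dtype=torch.float16, device=device)
k_nope = torch.randn(nnz, 512, dtype=torch.float16, device=device)
cos_sin_cache = torch.randn(max_pos, 64, dtype=torch.float32, device=device)  # placeholder values
pos_ids = torch.arange(nnz, dtype=torch.int32, device=device)

# Output tensors are allocated internally (float8_e4m3fn by default) when not passed in.
q_rope_fp8, k_rope_fp8, q_nope_fp8, k_nope_fp8 = mla_rope_quantize_fp8(
    q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids,
    is_neox=True, quant_scale_q=1.0, quant_scale_kv=1.0,
)
```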
