Commit b513f86

fix: separate out fp4 lib into sm90 and sm100 versions, add oob checking in fused moe (#1565)
## 📌 Description

This fixes an OOB issue in the fused MoE and creates separate sm90 and sm100 paths for fp4 quantization.

## 🔍 Related Issues

Fix required for vllm-project/vllm#23369.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Signed-off-by: Duncan Moss <[email protected]>
1 parent 4b30a91 commit b513f86

File tree: 4 files changed, +58 −17 lines

- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
- flashinfer/aot.py
- flashinfer/fp4_quantization.py
- tests/test_fp4_quantize.py

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 14 additions & 3 deletions

```diff
@@ -1755,7 +1755,6 @@ __global__ void finalizeMoeRoutingKernel(
     ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
     int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
     int const num_experts_per_node, int const start_expert_id) {
-  assert(orig_cols % 4 == 0);
   int64_t const original_row = blockIdx.x;
   int64_t const num_rows = gridDim.x;
   auto const offset = original_row * orig_cols;
@@ -1765,6 +1764,8 @@
   constexpr int64_t FINALIZE_ELEM_PER_THREAD =
       128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
 
+  assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0);
+
   int64_t const start_offset = threadIdx.x;
   int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
   int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
@@ -1795,6 +1796,11 @@
     int64_t const expanded_original_row = original_row + k_idx * num_rows;
     int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
 
+    int64_t expanded_rows = num_rows * experts_per_token;
+    if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
+      continue;
+    }
+
     float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
 
     auto const* expanded_permuted_rows_row_ptr =
@@ -1828,8 +1834,6 @@ __global__ void finalizeMoeRoutingNoFillingKernel(
     int const* permuted_row_to_unpermuted_row, int const* token_selected_experts,
     int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const orig_cols,
     int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) {
-  assert(orig_cols % 4 == 0);
-
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.wait;");
 #endif
@@ -1864,6 +1868,8 @@
   constexpr int64_t FINALIZE_ELEM_PER_THREAD =
       128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
 
+  assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0);
+
   int64_t const start_offset = threadIdx.x;
   int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
   int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
@@ -1889,6 +1895,11 @@
 
     int64_t const expanded_permuted_row_from_k_idx =
         unpermuted_row_to_permuted_row[source_row + k_idx * num_rows];
+    int64_t valid_tokens = expert_first_token_offset[num_experts_per_node];
+    if (expanded_permuted_row_from_k_idx < 0 ||
+        expanded_permuted_row_from_k_idx >= valid_tokens) {
+      continue;
+    }
 
     float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
 
```
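These hunks replace the stale `orig_cols % 4` assertion with one that matches the actual vector width (`FINALIZE_ELEM_PER_THREAD`), and range-check each entry of the row map before it is used to index the expanded GEMM output: `finalizeMoeRoutingKernel` bounds against `num_rows * experts_per_token`, while the no-filling variant bounds against `expert_first_token_offset[num_experts_per_node]`, the count of rows actually written. Below is a minimal CPU sketch of the guarded finalize step, under simplified shapes and with a hypothetical `finalize_moe_routing_reference` name; it illustrates the guard, it is not the kernel itself:

```python
import numpy as np

def finalize_moe_routing_reference(
    expanded_out: np.ndarray,   # [num_rows * experts_per_token, cols], permuted expert outputs
    scales: np.ndarray,         # [num_rows, experts_per_token], per-(token, expert) weights
    unpermuted_row_to_permuted_row: np.ndarray,  # [num_rows * experts_per_token] row map
    experts_per_token: int,
) -> np.ndarray:
    num_rows = scales.shape[0]
    expanded_rows = num_rows * experts_per_token
    out = np.zeros((num_rows, expanded_out.shape[1]), dtype=np.float32)
    for row in range(num_rows):
        for k in range(experts_per_token):
            permuted_row = int(unpermuted_row_to_permuted_row[row + k * num_rows])
            # OOB guard mirroring the kernel change: skip map entries that
            # point outside the expanded output (e.g. dropped or unwritten
            # rows) instead of indexing past the end of the buffer.
            if permuted_row < 0 or permuted_row >= expanded_rows:
                continue
            out[row] += float(scales[row, k]) * expanded_out[permuted_row].astype(np.float32)
    return out
```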

flashinfer/aot.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -11,7 +11,10 @@
 from .activation import act_func_def_str, gen_act_and_mul_module
 from .cascade import gen_cascade_module
-from .fp4_quantization import gen_fp4_quantization_module
+from .fp4_quantization import (
+    gen_fp4_quantization_sm100_module,
+    gen_fp4_quantization_sm90_module,
+)
 from .fused_moe import (
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
@@ -332,11 +335,12 @@ def gen_all_modules(
 
     if add_moe:
         jit_specs.append(gen_gemm_module())
-        jit_specs.append(gen_fp4_quantization_module())
         if has_sm90:
             jit_specs.append(gen_gemm_sm90_module())
+            jit_specs.append(gen_fp4_quantization_sm90_module())
             jit_specs.append(gen_cutlass_fused_moe_sm90_module())
         if has_sm100:
+            jit_specs.append(gen_fp4_quantization_sm100_module())
            jit_specs.append(gen_cutlass_fused_moe_sm100_module())
            jit_specs.append(gen_gemm_sm100_module())
 
```
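With this change the fp4 quantization module is no longer built unconditionally: each arch-specific `JitSpec` is generated only under the matching `has_sm90` / `has_sm100` flag. The sketch below shows one way such flags could be derived from the visible GPUs; `detect_sm_targets` is a hypothetical helper, and `aot.py` may compute the flags differently (e.g. from a user-supplied arch list):

```python
import torch

def detect_sm_targets() -> set:
    """Hypothetical helper: collect compute capabilities of all visible GPUs."""
    caps = set()
    for i in range(torch.cuda.device_count()):
        caps.add(torch.cuda.get_device_capability(i))
    return caps

caps = detect_sm_targets()
has_sm90 = any(major == 9 for major, _ in caps)    # Hopper-class devices
has_sm100 = any(major == 10 for major, _ in caps)  # Blackwell-class devices
```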

flashinfer/fp4_quantization.py

Lines changed: 34 additions & 11 deletions

```diff
@@ -17,13 +17,13 @@
 import functools
 from enum import Enum
 from types import SimpleNamespace
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 
 from .jit import JitSpec
 from .jit import env as jit_env
-from .jit import gen_jit_spec, sm100a_nvcc_flags
+from .jit import gen_jit_spec, sm100a_nvcc_flags, sm90a_nvcc_flags
 from .utils import (
     device_support_pdl,
     get_shuffle_matrix_a_row_indices,
@@ -62,9 +62,17 @@ def _pad_scale_factors(
     ).contiguous()
 
 
-def gen_fp4_quantization_module() -> JitSpec:
+def gen_fp4_quantization_sm100_module() -> JitSpec:
+    return gen_fp4_quantization_module(sm100a_nvcc_flags, "100")
+
+
+def gen_fp4_quantization_sm90_module() -> JitSpec:
+    return gen_fp4_quantization_module(sm90a_nvcc_flags, "90")
+
+
+def gen_fp4_quantization_module(nvcc_flags: List[str], device_arch: str) -> JitSpec:
     return gen_jit_spec(
-        "fp4_quantization",
+        f"fp4_quantization_{device_arch}",
         [
             jit_env.FLASHINFER_CSRC_DIR
             / "nv_internal/tensorrt_llm/thop/fp4Quantize.cpp",
@@ -75,7 +83,7 @@ def gen_fp4_quantization_module() -> JitSpec:
             jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp",
             jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp",
         ],
-        extra_cuda_cflags=sm100a_nvcc_flags
+        extra_cuda_cflags=nvcc_flags
         + [
             "-DENABLE_BF16",
             "-DENABLE_FP8",
@@ -94,8 +102,13 @@
 
 
 @functools.cache
-def get_fp4_quantization_module():
-    module = gen_fp4_quantization_module().build_and_load()
+def get_fp4_quantization_module(backend: str = "100"):
+    if backend == "100":
+        module = gen_fp4_quantization_sm100_module().build_and_load()
+    elif backend == "90":
+        module = gen_fp4_quantization_sm90_module().build_and_load()
+    else:
+        raise ValueError(f"Invalid backend: {backend}")
 
     @register_custom_op(
         "flashinfer::fp4_quantize_sm100",
@@ -310,7 +323,7 @@ def fp4_quantize(
     assert input.shape[-1] % sf_vec_size == 0
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)
-    x_q, sf = get_fp4_quantization_module().fp4_quantize_sm100(
+    x_q, sf = get_fp4_quantization_module("100").fp4_quantize_sm100(
         input,
         global_scale,
         sf_vec_size,
@@ -346,7 +359,11 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
     assert unswizzled_sf.dtype == torch.uint8, (
         f"Input dtype must be uint8, got {unswizzled_sf.dtype}"
     )
-    return get_fp4_quantization_module().block_scale_interleave_sm100(
+
+    major, minor = torch.cuda.get_device_capability()
+    device_arch = f"{major * 10 + minor}"
+
+    return get_fp4_quantization_module(device_arch).block_scale_interleave_sm100(
         unswizzled_sf,
     )
 
@@ -380,7 +397,11 @@ def e2m1_and_ufp8sf_scale_to_float(
         torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
 
     """
-    return get_fp4_quantization_module().e2m1_and_ufp8sf_scale_to_float_sm100(
+    major, minor = torch.cuda.get_device_capability()
+    device_arch = f"{major * 10 + minor}"
+    return get_fp4_quantization_module(
+        device_arch
+    ).e2m1_and_ufp8sf_scale_to_float_sm100(
         e2m1_tensor,
         ufp8_scale_tensor,
         global_scale_tensor,
@@ -547,7 +568,9 @@ def mxfp4_dequantize_host(
     Returns:
         torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
     """
-    return get_fp4_quantization_module().mxfp4_dequantize_host(
+    major, minor = torch.cuda.get_device_capability()
+    device_arch = f"{major * 10 + minor}"
+    return get_fp4_quantization_module(device_arch).mxfp4_dequantize_host(
         weight,
         scale,
         group_size,
```
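The host-side wrappers now derive the backend from the current device's compute capability, so the same Python entry points transparently build and load either the sm90 or the sm100 module. A minimal sketch of that dispatch, using a hypothetical `_device_backend` helper:

```python
import torch

def _device_backend() -> str:
    """Derive the backend string the same way the wrappers above do."""
    major, minor = torch.cuda.get_device_capability()
    return f"{major * 10 + minor}"  # "90" on SM 9.0, "100" on SM 10.0

# get_fp4_quantization_module("90")  -> builds/loads fp4_quantization_90 (sm90a flags)
# get_fp4_quantization_module("100") -> builds/loads fp4_quantization_100 (sm100a flags)
backend = _device_backend()
```

Note that only `"90"` and `"100"` are accepted: any other capability string (for example `"89"` on an SM 8.9 device) falls through to the `ValueError` branch of `get_fp4_quantization_module`.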

tests/test_fp4_quantize.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -295,7 +295,10 @@ def test_e2m1_dequantization(
     )
 
 
-def test_mxfp4_quantize_roundtrip():
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_mxfp4_quantize_roundtrip(device: str):
+    if not is_sm100a_supported(torch.device(device)):
+        pytest.skip("Nvfp4 Requires compute capability of 10 or above")
     x = torch.randn((128, 64), device="cuda", dtype=torch.bfloat16) / 10
 
     quant_a, sfs = mxfp4_quantize(x)
```
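The test now parametrizes over the available devices and skips on hardware below SM 10.0, since this path is only compiled for sm100a. A minimal sketch of the same gating pattern, with a hypothetical `_supports_sm100` check standing in for the test suite's `is_sm100a_supported` helper:

```python
import pytest
import torch

DEVICES = [f"cuda:{i}" for i in range(torch.cuda.device_count())]

def _supports_sm100(device: str) -> bool:
    # Hypothetical stand-in: gate on compute capability major version.
    major, _ = torch.cuda.get_device_capability(torch.device(device))
    return major >= 10

@pytest.mark.parametrize("device", DEVICES)
def test_sm100_only_path(device: str):
    if not _supports_sm100(device):
        pytest.skip("requires compute capability 10.0 or above")
    x = torch.randn((128, 64), device=device, dtype=torch.bfloat16)
    assert x.shape == (128, 64)  # placeholder for the real kernel check
```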
