Skip to content

Commit fe442a2

Browse files
authored
unittest: Add Mxfp4 trtllm-gen moe unit tests (#1399)
<!-- .github/pull_request_template.md --> ## 📌 Description - add mxfp4 quantization unit test - add mxfp4 x mxfp8 and mxfp4 x bf16 moe unit test ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Signed-off-by: siyuanf <[email protected]>
1 parent 872fe4b commit fe442a2

File tree

5 files changed

+197
-94
lines changed

5 files changed

+197
-94
lines changed

tests/test_fp4_quantize.py

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant
5+
from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant
66

77
from flashinfer import (
88
e2m1_and_ufp8sf_scale_to_float,
@@ -88,30 +88,47 @@ def unswizzle_sf(
8888
@pytest.mark.parametrize("shape", SHAPES)
8989
@pytest.mark.parametrize("seed", SEEDS)
9090
@pytest.mark.parametrize("device", CUDA_DEVICES)
91+
@pytest.mark.parametrize("sf_use_ue8m0", [False, True])
92+
@pytest.mark.parametrize("is_swizzled", [False, True])
9193
@torch.inference_mode()
9294
def test_fp4_quantization(
9395
dtype: torch.dtype,
9496
shape: tuple[int, int],
9597
seed: int,
9698
device: str,
99+
sf_use_ue8m0: bool,
100+
is_swizzled: bool,
97101
) -> None:
98102
if not is_sm100a_supported(torch.device(device)):
99103
pytest.skip("Nvfp4 Requires compute capability of 10 or above")
100104
torch.set_default_device(device)
101105
torch.manual_seed(seed)
102106
m, n = shape
107+
sf_vec_size = 32 if sf_use_ue8m0 else 16
103108
x = torch.randn((m, n), dtype=dtype)
104109
tensor_amax = torch.abs(x).max().to(torch.float32)
105-
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
106-
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale, BLOCK_SIZE)
107-
out, out_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False)
108-
assert n % BLOCK_SIZE == 0, f"cols needs to be {BLOCK_SIZE} divisible"
109-
scale_ans = recover_swizzled_scales(
110-
out_scale.reshape(-1, n // BLOCK_SIZE).view(torch.float8_e4m3fn),
111-
m,
112-
n,
113-
BLOCK_SIZE,
110+
if sf_use_ue8m0:
111+
global_scale = torch.tensor(1.0, dtype=torch.float32)
112+
else:
113+
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
114+
out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size, sf_use_ue8m0)
115+
out, out_scale = fp4_quantize(
116+
x, global_scale, sf_vec_size, sf_use_ue8m0, is_swizzled
114117
)
118+
assert n % sf_vec_size == 0, f"cols needs to be {sf_vec_size} divisible"
119+
if sf_use_ue8m0:
120+
out_scale = (out_scale.to(torch.int32) << 23).view(torch.float32)
121+
else:
122+
out_scale = out_scale.view(torch.float8_e4m3fn).to(torch.float32)
123+
if is_swizzled:
124+
scale_ans = recover_swizzled_scales(
125+
out_scale.reshape(-1, n // sf_vec_size),
126+
m,
127+
n,
128+
sf_vec_size,
129+
)
130+
else:
131+
scale_ans = out_scale
115132
out_ans = cast_from_fp4(out).reshape(m, n)
116133
torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=1e-1)
117134
torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1)

tests/test_trtllm_gen_context.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant
5+
from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant
66

77
import flashinfer
88
from flashinfer.utils import FP4Tensor
@@ -437,7 +437,7 @@ def test_trtllm_batch_prefill(
437437

438438
if o_dtype == "nvfp4":
439439
output = cast_from_fp4(output)
440-
output_ref, out_scale_factor_ref = ref_nvfp4_quant(output_ref, o_sf_scale, 16)
440+
output_ref, out_scale_factor_ref = ref_fp4_quant(output_ref, o_sf_scale, 16)
441441
out_scale_factor = recover_swizzled_scales(
442442
out_scale_factor,
443443
output.shape[0],

tests/test_trtllm_gen_decode.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import torch
55
import torch.nn.functional as F
6-
from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant
6+
from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant
77

88
import flashinfer
99
from flashinfer.utils import FP4Tensor
@@ -328,7 +328,7 @@ def test_trtllm_batch_decode_fmha(
328328

329329
if o_dtype == "nvfp4":
330330
output = cast_from_fp4(output)
331-
output_ref, out_scale_factor_ref = ref_nvfp4_quant(output_ref, o_sf_scale, 16)
331+
output_ref, out_scale_factor_ref = ref_fp4_quant(output_ref, o_sf_scale, 16)
332332
out_scale_factor = recover_swizzled_scales(
333333
out_scale_factor,
334334
output.shape[0],

0 commit comments

Comments
(0)