@@ -2,7 +2,7 @@
 
 import pytest
 import torch
-from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant
+from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant
 
 from flashinfer import (
     e2m1_and_ufp8sf_scale_to_float,
@@ -88,30 +88,47 @@ def unswizzle_sf(
 @pytest.mark.parametrize("shape", SHAPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("sf_use_ue8m0", [False, True])
+@pytest.mark.parametrize("is_swizzled", [False, True])
 @torch.inference_mode()
 def test_fp4_quantization(
     dtype: torch.dtype,
     shape: tuple[int, int],
     seed: int,
     device: str,
+    sf_use_ue8m0: bool,
+    is_swizzled: bool,
 ) -> None:
     if not is_sm100a_supported(torch.device(device)):
         pytest.skip("Nvfp4 Requires compute capability of 10 or above")
     torch.set_default_device(device)
     torch.manual_seed(seed)
     m, n = shape
+    sf_vec_size = 32 if sf_use_ue8m0 else 16
     x = torch.randn((m, n), dtype=dtype)
     tensor_amax = torch.abs(x).max().to(torch.float32)
-    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
-    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale, BLOCK_SIZE)
-    out, out_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False)
-    assert n % BLOCK_SIZE == 0, f"cols needs to be {BLOCK_SIZE} divisible"
-    scale_ans = recover_swizzled_scales(
-        out_scale.reshape(-1, n // BLOCK_SIZE).view(torch.float8_e4m3fn),
-        m,
-        n,
-        BLOCK_SIZE,
+    if sf_use_ue8m0:
+        global_scale = torch.tensor(1.0, dtype=torch.float32)
+    else:
+        global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size, sf_use_ue8m0)
+    out, out_scale = fp4_quantize(
+        x, global_scale, sf_vec_size, sf_use_ue8m0, is_swizzled
     )
+    assert n % sf_vec_size == 0, f"cols needs to be {sf_vec_size} divisible"
+    if sf_use_ue8m0:
+        out_scale = (out_scale.to(torch.int32) << 23).view(torch.float32)
+    else:
+        out_scale = out_scale.view(torch.float8_e4m3fn).to(torch.float32)
+    if is_swizzled:
+        scale_ans = recover_swizzled_scales(
+            out_scale.reshape(-1, n // sf_vec_size),
+            m,
+            n,
+            sf_vec_size,
+        )
+    else:
+        scale_ans = out_scale
     out_ans = cast_from_fp4(out).reshape(m, n)
     torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=1e-1)
     torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1)
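For context on the ue8m0 path: a UE8M0 scale factor stores only a biased exponent in a single byte, so shifting that byte into the float32 exponent field (the `<< 23` line in the diff) reinterprets it as the power of two it encodes. Below is a minimal sketch of that decode, using made-up scale bytes rather than actual fp4_quantize output:

import torch

# Hypothetical UE8M0 scale bytes (biased exponents); real values come from the kernel.
ue8m0_bytes = torch.tensor([118, 127, 130], dtype=torch.uint8)

# Place each byte into the float32 exponent field (bits 23..30). With a zero
# mantissa, a stored exponent e decodes to 2 ** (e - 127).
decoded = (ue8m0_bytes.to(torch.int32) << 23).view(torch.float32)

print(decoded)  # 2**-9, 2**0, 2**3 -> tensor([0.0020, 1.0000, 8.0000])

The e4m3 path needs no bit tricks: its scale bytes are reinterpreted as torch.float8_e4m3fn and upcast to float32, as in the else branch of the test.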