@@ -5,6 +5,8 @@
 
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
@@ -116,3 +118,32 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
                         current_platform.get_cu_count())
 
     assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+    torch.manual_seed(seed)
+
+    A = torch.rand(n, k, device="cuda")
+    B = torch.rand(m, k, device="cuda")
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
+
+    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
+                                                 scale_b, bias)
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b,
+                               bias=bias)
+    assert torch.allclose(output, ref_out, rtol=0.01)
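
For context, the function under test performs a per-tensor W8A8 scaled matmul: both operands are FP8 with one scale each, and the product is rescaled, bias-added, and cast to the output dtype. Below is a minimal dequantize-then-matmul sketch of those semantics; the helper name `dequant_scaled_mm_ref` is illustrative and not part of this change, it simply mirrors what `torch._scaled_mm` is asked to compute in the new test.

```python
from typing import Optional

import torch


def dequant_scaled_mm_ref(a_fp8: torch.Tensor,
                          b_fp8_t: torch.Tensor,
                          out_dtype: torch.dtype,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Dequantize the per-tensor FP8 operands back to fp32, run a plain matmul,
    # add the optional bias, and cast to the requested output dtype.
    a = a_fp8.to(torch.float32) * scale_a
    b = b_fp8_t.to(torch.float32) * scale_b
    out = a @ b
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)
```

The argument order intentionally mirrors the `rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a, scale_b, bias)` call in the test, so the two can be compared side by side (up to FP8 rounding, hence the `rtol=0.01` tolerance).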