|
3 | 3 |
|
4 | 4 | import comfy_kitchen as ck |
5 | 5 | from comfy_kitchen.float_utils import ( |
6 | | - F4_E2M1_EPS, |
7 | 6 | F4_E2M1_MAX, |
8 | | - F8_E4M3_EPS, |
9 | 7 | F8_E4M3_MAX, |
10 | 8 | fp4_x2_to_f32, |
11 | 9 | ) |
@@ -146,67 +144,48 @@ def capable_backends(self, device): |
146 | 144 |
|
147 | 145 | @pytest.mark.parametrize("m,k", [ |
148 | 146 | (1024, 2048), |
| 147 | + (512, 1024), |
149 | 148 | (129, 128), # Edge case: odd rows requiring padding |
150 | 149 | (33, 65), # Edge case: both dimensions odd |
151 | 150 | ]) |
152 | 151 | def test_quantize_nvfp4_all_backends(self, capable_backends, device, seed, m, k): |
153 | | - """Test NVFP4 quantization across all capable backends.""" |
154 | | - for backend_name in capable_backends: |
155 | | - inputs = ConstraintAwareTestInputs("quantize_nvfp4", backend_name, device) |
156 | | - x = inputs.tensor("x", shape=(m, k), dtype=torch.bfloat16) |
157 | | - x = x * 4 # Scale up for better test coverage |
158 | | - |
159 | | - scale = torch.max(torch.abs(x)) / (F8_E4M3_MAX * F4_E2M1_MAX) |
160 | | - scale = scale.to(torch.float32) |
161 | | - |
162 | | - needs_padding = (m % 16 != 0) or (k % 16 != 0) |
163 | | - |
164 | | - with ck.use_backend(backend_name): |
165 | | - qx, sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding) |
166 | | - |
167 | | - assert qx.dtype == torch.uint8 |
168 | | - assert sx.dtype == torch.float8_e4m3fn |
169 | | - |
170 | | - @pytest.mark.parametrize("m,k", [(512, 1024)]) |
171 | | - def test_quantize_nvfp4_cross_backend_consistency( |
172 | | - self, capable_backends, device, seed, m, k |
173 | | - ): |
174 | | - """Test that all backends produce consistent NVFP4 results.""" |
175 | | - if len(capable_backends) < 2: |
176 | | - pytest.skip("Need at least 2 backends for cross-validation") |
| 152 | + """Test NVFP4 quantization across all capable backends with accuracy testing.""" |
| 153 | + if "eager" not in capable_backends: |
| 154 | + pytest.skip("Need eager backend as reference") |
177 | 155 |
|
| 156 | + # Create test input |
178 | 157 | x = torch.randn(m, k, device=device, dtype=torch.bfloat16) * 4 |
179 | 158 | scale = torch.max(torch.abs(x)) / (F8_E4M3_MAX * F4_E2M1_MAX) |
180 | 159 | scale = scale.to(torch.float32) |
| 160 | + needs_padding = (m % 16 != 0) or (k % 16 != 0) |
| 161 | + |
| 162 | + with ck.use_backend("eager"): |
| 163 | + ref_qx, ref_sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding) |
181 | 164 |
|
182 | | - results = {} |
183 | 165 | for backend_name in capable_backends: |
184 | 166 | with ck.use_backend(backend_name): |
185 | | - qx, sx = ck.quantize_nvfp4(x, scale) |
186 | | - results[backend_name] = (qx, sx) |
| 167 | + qx, sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding) |
187 | 168 |
|
188 | | - # Compare all against first |
189 | | - ref_backend = capable_backends[0] |
190 | | - ref_qx, ref_sx = results[ref_backend] |
| 169 | + # Check basic properties |
| 170 | + assert qx.dtype == torch.uint8 |
| 171 | + assert sx.dtype == torch.float8_e4m3fn |
191 | 172 |
|
192 | | - for backend_name, (qx, sx) in results.items(): |
193 | | - if backend_name != ref_backend: |
194 | 173 | assert_values_close( |
195 | 174 | sx.to(torch.float32), |
196 | 175 | ref_sx.to(torch.float32), |
197 | | - rtol=F8_E4M3_EPS, |
198 | | - atol=F8_E4M3_EPS, |
199 | | - name=f"scales ({backend_name} vs {ref_backend})" |
| 176 | + rtol=1e-5, |
| 177 | + atol=1e-3, |
| 178 | + name=f"scales ({backend_name} vs eager)" |
200 | 179 | ) |
201 | 180 |
|
202 | 181 | qx_f32 = fp4_x2_to_f32(qx) |
203 | 182 | ref_qx_f32 = fp4_x2_to_f32(ref_qx) |
204 | 183 | assert_values_close( |
205 | 184 | qx_f32, |
206 | 185 | ref_qx_f32, |
207 | | - rtol=F4_E2M1_EPS, |
208 | | - atol=F4_E2M1_EPS, |
209 | | - name=f"quantized ({backend_name} vs {ref_backend})" |
| 186 | + rtol=1e-2, |
| 187 | + atol=2.0, |
| 188 | + name=f"quantized data ({backend_name} vs eager)" |
210 | 189 | ) |
211 | 190 |
|
212 | 191 | def test_quantize_nvfp4_cpu_fallback(self, seed): |
@@ -235,28 +214,48 @@ def capable_backends(self, device): |
235 | 214 | pytest.skip(f"No backend supports dequantize_nvfp4 on {device}") |
236 | 215 | return backends |
237 | 216 |
|
238 | | - @pytest.mark.parametrize("m,k", [(1024, 2048), (512, 4096)]) |
| 217 | + @pytest.mark.parametrize("m,k", [ |
| 218 | + (1024, 2048), |
| 219 | + (512, 4096), |
| 220 | + (129, 128), # Edge case with padding |
| 221 | + ]) |
239 | 222 | @pytest.mark.parametrize("output_dtype", [torch.float16, torch.bfloat16]) |
240 | 223 | def test_dequantize_nvfp4_all_backends( |
241 | 224 | self, capable_backends, device, seed, m, k, output_dtype |
242 | 225 | ): |
243 | | - """Test NVFP4 dequantization across all capable backends.""" |
| 226 | + """Test NVFP4 dequantization across all capable backends with accuracy testing.""" |
| 227 | + if "eager" not in capable_backends: |
| 228 | + pytest.skip("Need eager backend as reference") |
| 229 | + |
244 | 230 | x = torch.randn(m, k, device=device, dtype=torch.bfloat16) * 4 |
245 | 231 | scale = torch.max(torch.abs(x)) / (F8_E4M3_MAX * F4_E2M1_MAX) |
246 | 232 | scale = scale.to(torch.float32) |
| 233 | + needs_padding = (m % 16 != 0) or (k % 16 != 0) |
247 | 234 |
|
248 | 235 | # Quantize with eager |
249 | 236 | with ck.use_backend("eager"): |
250 | | - qx, sx = ck.quantize_nvfp4(x, scale) |
| 237 | + qx, sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding) |
| 238 | + ref_result = ck.dequantize_nvfp4(qx, scale, sx, output_type=output_dtype) |
| 239 | + # Unpad if needed |
| 240 | + ref_result = ref_result[:m, :k] |
251 | 241 |
|
252 | 242 | for backend_name in capable_backends: |
253 | 243 | with ck.use_backend(backend_name): |
254 | 244 | result = ck.dequantize_nvfp4(qx, scale, sx, output_type=output_dtype) |
| 245 | + result = result[:m, :k] # Unpad if needed |
255 | 246 |
|
256 | | - assert result.shape == x.shape |
| 247 | + assert result.shape == (m, k) |
257 | 248 | assert result.dtype == output_dtype |
258 | 249 | assert result.device == x.device |
259 | 250 |
|
| 251 | + assert_values_close( |
| 252 | + result, |
| 253 | + ref_result, |
| 254 | + rtol=1e-3, |
| 255 | + atol=1e-2, |
| 256 | + name=f"dequantized output ({backend_name} vs eager)" |
| 257 | + ) |
| 258 | + |
260 | 259 |
|
261 | 260 | class TestScaledMMNVFP4: |
262 | 261 | """NVFP4 matrix multiplication tests.""" |
|
0 commit comments