Skip to content

Commit 7e92477

Browse files
Giuseppe5nickfraser
authored andcommitted
Reduce test combinations
1 parent 6fbeaf3 commit 7e92477

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tests/brevitas_ort/quant_module_cases.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
class QuantWBIOLCases:
1717

18-
@parametrize('rounding_type', ['round', 'floor'], ids=[r for r in ['round', 'floor']])
18+
@parametrize(
19+
'rounding_type', ['round', 'floor'], ids=[f'rtype_{r}' for r in ['round', 'floor']])
1920
@parametrize('impl', QUANT_WBIOL_IMPL, ids=[f'{c.__name__}' for c in QUANT_WBIOL_IMPL])
2021
@parametrize('input_bit_width', BIT_WIDTHS, ids=[f'i{b}' for b in BIT_WIDTHS])
2122
@parametrize('weight_bit_width', BIT_WIDTHS, ids=[f'w{b}' for b in BIT_WIDTHS])
@@ -37,9 +38,9 @@ def case_quant_wbiol(
3738
weight_quant, io_quant = quantizers
3839
is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat
3940
is_dynamic = io_quant == ShiftedUint8DynamicActPerTensorFloat
40-
if is_fp8:
41+
if is_fp8 or rounding_type == 'floor':
4142
if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8:
42-
pytest.skip('FP8 export requires total bitwidth equal to 8')
43+
pytest.skip('FP8 export and FLOOR rounding require all bitwidths equal to 8')
4344
torch.use_deterministic_algorithms(False)
4445
else:
4546
torch.use_deterministic_algorithms(True)

0 commit comments

Comments
 (0)