|
1 | | -import os, inspect |
| 1 | +import os, random |
2 | 2 | import pytest |
3 | 3 | import torch |
4 | 4 | import onnx |
@@ -157,7 +157,7 @@ def forward(self, x): |
157 | 157 |
|
158 | 158 | class TestExporter: |
159 | 159 | def setup_method(self): |
160 | | - set_seed(0) |
| 160 | + set_seed(1) |
161 | 161 | self.in_channels = 3 |
162 | 162 | self.out_channels = 4 |
163 | 163 | self.onnx_file = f"./tmp_model_{np.random.randint(1e10)}.onnx" |
@@ -291,10 +291,10 @@ def _assert_fq_quant_params_match(self, quantized_model, onnx_model_dict, a_qmet |
291 | 291 | @pytest.mark.parametrize('w_qmethod', [mctq.QuantizationMethod.POWER_OF_TWO, |
292 | 292 | mctq.QuantizationMethod.SYMMETRIC, |
293 | 293 | mctq.QuantizationMethod.UNIFORM]) |
294 | | - @pytest.mark.parametrize('a_qmethod, tol', [(mctq.QuantizationMethod.POWER_OF_TWO, 1e-8), |
295 | | - (mctq.QuantizationMethod.SYMMETRIC, 1e-2), |
296 | | - (mctq.QuantizationMethod.UNIFORM, 1e-2)]) |
297 | | - @pytest.mark.parametrize('abits', [8, 16]) |
| 294 | + @pytest.mark.parametrize('a_qmethod', [mctq.QuantizationMethod.POWER_OF_TWO, |
| 295 | + mctq.QuantizationMethod.SYMMETRIC, |
| 296 | + mctq.QuantizationMethod.UNIFORM]) |
| 297 | + @pytest.mark.parametrize('abits, tol', [(8, 1e-4), (16, 1e-3)]) |
298 | 298 | def test_mct_ptq_and_exporter_mctq(self, w_qmethod, abits, a_qmethod, tol): |
299 | 299 | quantized_model = self._run_mct(self.get_model(), self.representative_dataset(1), abits, a_qmethod, w_qmethod) |
300 | 300 | onnx_model_dict = self._run_exporter(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ) |
|
0 commit comments