2626from nncf .quantization .algorithms .min_max .torch_fx_backend import FXMinMaxAlgoBackend
2727from nncf .quantization .fake_quantize import calculate_quantizer_parameters
2828from nncf .tensor import Tensor
29+ from nncf .tensor .definitions import TensorDataType
2930
3031INPUT_SHAPE = (2 , 3 , 4 , 5 )
3132
@@ -79,7 +80,7 @@ class CaseQuantParams:
7980
8081
8182@pytest .mark .parametrize ("case_to_test" , SYM_CASES )
82- @pytest .mark .parametrize ("dtype" , [IntDtype . UINT8 , IntDtype . INT8 ])
83+ @pytest .mark .parametrize ("dtype" , [TensorDataType . uint8 , TensorDataType . int8 ])
8384def test_quantizer_params_sym (case_to_test : CaseQuantParams , dtype : Optional [IntDtype ]):
8485 per_ch = case_to_test .per_channel
8586 narrow_range = case_to_test .narrow_range
@@ -97,7 +98,7 @@ def test_quantizer_params_sym(case_to_test: CaseQuantParams, dtype: Optional[Int
9798 quantizer = _get_quantizer (case_to_test , qconfig )
9899 assert quantizer .qscheme is torch .per_channel_symmetric if case_to_test .per_channel else torch .per_tensor_symmetric
99100
100- signed = signedness_to_force or dtype is IntDtype . INT8
101+ signed = signedness_to_force or dtype is TensorDataType . int8
101102 if signed :
102103 assert torch .allclose (quantizer .zero_point , torch .tensor (0 , dtype = torch .int8 ))
103104 else :
@@ -380,7 +381,7 @@ def test_quantizer_params_sym_nr(case_to_test: CaseQuantParams, ref_signed: bool
380381
381382
382383@pytest .mark .parametrize ("case_to_test,ref_zp" , ASYM_CASES )
383- @pytest .mark .parametrize ("dtype" , [IntDtype . UINT8 , IntDtype . INT8 ])
384+ @pytest .mark .parametrize ("dtype" , [TensorDataType . uint8 , TensorDataType . int8 ])
384385def test_quantizer_params_asym (case_to_test : CaseQuantParams , ref_zp : Union [int , list [int ]], dtype : Optional [IntDtype ]):
385386 per_ch = case_to_test .per_channel
386387 narrow_range = case_to_test .narrow_range
@@ -397,7 +398,7 @@ def test_quantizer_params_asym(case_to_test: CaseQuantParams, ref_zp: Union[int,
397398 quantizer = _get_quantizer (case_to_test , qconfig )
398399 assert quantizer .qscheme is torch .per_channel_affine if case_to_test .per_channel else torch .per_tensor_affine
399400
400- signed = dtype is IntDtype . INT8
401+ signed = dtype is TensorDataType . int8
401402 ref_zp = torch .tensor (ref_zp )
402403 if not signed :
403404 ref_zp += 127 if narrow_range else 128
0 commit comments