44# This source code is licensed under the BSD 3-Clause license found in the 
55# LICENSE file in the root directory of this source tree. 
66
7- import  platform 
8- import  sys 
97from  copy  import  deepcopy 
108
119import  pytest 
1513    StretchedIntxWeightConfig ,
1614    StretchedUnifTorchaoQuantizer ,
1715)
18- from  torchao .prototype .quantization .dynamic_activation_lut  import  (
19-     StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig ,
16+ from  torchao .prototype .quantization .int8_lut_tensor . int8_lut_tensor  import  (
17+     _is_kernel_library_loaded ,
2018)
19+ from  torchao .prototype .tensor_conversion .api  import  _convert_model_for_aarch64 
2120from  torchao .quantization  import  quantize_ 
2221from  torchao .quantization .granularity  import  PerAxis , PerGroup 
23- from  torchao .quantization .quant_api  import  _is_linear 
2422from  torchao .quantization .utils  import  compute_error 
2523
26- is_arm64_mac  =  sys .platform  ==  "darwin"  and  platform .machine () ==  "arm64" 
27- 
2824
2925class  ToyLinearModel (torch .nn .Module ):
3026    def  __init__ (self , d1 = 512 , d2 = 256 , d3 = 128 , d4 = 8 ):
@@ -59,7 +55,9 @@ def run_before_and_after_tests():
5955@pytest .mark .parametrize ("granularity" , [PerGroup (32 ), PerAxis (0 )]) 
6056@pytest .mark .parametrize ("bit_width" , [1 , 2 , 3 , 4 ]) 
6157@pytest .mark .parametrize ("lead_dim" , [(5 ,), (2 , 3 )]) 
62- @pytest .mark .skipif (not  is_arm64_mac , reason = "requires arm64 mac" ) 
58+ @pytest .mark .skipif ( 
59+     not  _is_kernel_library_loaded (), reason = "Kernel library is not loaded"  
60+ ) 
6361def  test_parq_conversion (dtype , granularity , bit_width , lead_dim ):
6462    torch .manual_seed (0 )
6563    quantizer  =  StretchedUnifTorchaoQuantizer (bit_width )
@@ -68,38 +66,22 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
6866        quant_min = quantizer .quant_min ,
6967        quant_max = quantizer .quant_max ,
7068        granularity = granularity ,
71-         activation_quantization = None ,
72-         version = 1 ,
69+         activation_quantization = "int8_asym_per_token" ,
7370    )
7471
7572    parq_model  =  ToyLinearModel (128 , 256 , 128 , 1 ).to (dtype )
7673    activations  =  parq_model .example_inputs (lead_dim = lead_dim , dtype = dtype )
77-     parq_model_with_dyn_quant  =  deepcopy (parq_model )
7874    quantize_ (parq_model , config )
7975
80-     # Apply dynamic activation to parq model.  This will serve as the LUT reference 
81-     dyn_act_config  =  deepcopy (config )
82-     dyn_act_config .activation_quantization  =  "int8_asym_per_token" 
83-     quantize_ (parq_model_with_dyn_quant , dyn_act_config , filter_fn = _is_linear )
84- 
8576    # Convert PARQ model to lowbit LUT model 
8677    lut_model  =  deepcopy (parq_model )
87-     conversion_config  =  (
88-         StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig (
89-             config .b , config .granularity 
90-         )
91-     )
92-     quantize_ (lut_model , conversion_config , filter_fn = conversion_config .get_filter_fn ())
78+     _convert_model_for_aarch64 (lut_model , tensor_type = "int8_lut_tensor" )
9379
9480    # Run both models and compare 
9581    parq_out  =  parq_model (activations )
96-     parq_with_dyn_quant_out  =  parq_model_with_dyn_quant (activations )
9782    lut_out  =  lut_model (activations )
9883
99-     sqnr  =  compute_error (parq_out , parq_with_dyn_quant_out ).item ()
100-     assert  sqnr  >  20.0 , f"sqnr { sqnr }   is too low" 
101- 
102-     sqnr  =  compute_error (lut_out , parq_with_dyn_quant_out ).item ()
84+     sqnr  =  compute_error (parq_out , lut_out ).item ()
10385    if  dtype  ==  torch .float32 :
10486        assert  sqnr  >  40.0 , f"sqnr { sqnr }   is too low" 
10587    elif  dtype  ==  torch .bfloat16 :
@@ -112,32 +94,27 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
11294@pytest .mark .parametrize ("granularity" , [PerGroup (32 ), PerAxis (0 )]) 
11395@pytest .mark .parametrize ("bit_width" , [1 , 2 , 3 , 4 ]) 
11496@pytest .mark .parametrize ("lead_dim" , [(5 ,), (2 , 3 )]) 
115- @pytest .mark .skipif (not  is_arm64_mac , reason = "requires arm64 mac" ) 
97+ @pytest .mark .skipif ( 
98+     not  _is_kernel_library_loaded (), reason = "Kernel library is not loaded"  
99+ ) 
116100def  test_export (dtype , granularity , bit_width , lead_dim ):
117101    quantizer  =  StretchedUnifTorchaoQuantizer (bit_width )
118102    config  =  StretchedIntxWeightConfig (
119103        b = bit_width ,
120104        quant_min = quantizer .quant_min ,
121105        quant_max = quantizer .quant_max ,
122106        granularity = granularity ,
123-         activation_quantization = None ,
124-         version = 1 ,
107+         activation_quantization = "int8_asym_per_token" ,
125108    )
126109
127110    parq_model  =  ToyLinearModel (128 , 256 , 128 , 8 ).to (dtype )
128111    activations  =  parq_model .example_inputs (lead_dim = lead_dim )
129112    quantize_ (parq_model , config )
130113
131-     conversion_config  =  (
132-         StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig (
133-             config .b , config .granularity 
134-         )
135-     )
136-     quantize_ (
137-         parq_model , conversion_config , filter_fn = conversion_config .get_filter_fn ()
138-     )
114+     _convert_model_for_aarch64 (parq_model )
139115
140116    ep  =  torch .export .export (parq_model , (activations ,))
117+ 
141118    assert  (
142119        f"torch.ops.torchao._linear_8bit_act_{ bit_width }  bit_weight.default" 
143120        in  ep .graph_module .code 
0 commit comments