 #
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
-import copy
 import sys
 import unittest
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 from executorch.runtime import Runtime
-from torchao.quantization import quantize_, PerGroup, PerAxis, IntxWeightOnlyConfig
+from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
 
 _TEST_RUNTIME = sys.platform == "darwin"
 
@@ -30,10 +29,12 @@ def _coreml_partitioner(self):
         return CoreMLPartitioner(compile_specs=compile_specs)
 
     def _get_test_model(self):
-        model = torch.nn.Sequential(torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU())
+        model = torch.nn.Sequential(
+            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
+        )
         example_inputs = (torch.LongTensor([0]),)
         return model, example_inputs
-
+
     def _compare_outputs(self, executorch_program, eager_program, example_inputs):
         if not _TEST_RUNTIME:
             return
@@ -45,10 +46,14 @@ def _compare_outputs(self, executorch_program, eager_program, example_inputs):
         self.assertTrue(
             torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
         )
-
+
     def test_dequantize_affine_b4w_embedding(self):
         model, example_inputs = self._get_test_model()
-        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)), lambda m, fqn: isinstance(m, torch.nn.Embedding))
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
         ep = torch.export.export(model, example_inputs)
         delegated_program = executorch.exir.to_edge_transform_and_lower(
             ep,
@@ -65,7 +70,10 @@ def test_dequantize_affine_b4w_embedding(self):
 
     def test_dequantize_affine_b4w_linear(self):
         model, example_inputs = self._get_test_model()
-        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
         ep = torch.export.export(model, example_inputs)
         delegated_program = executorch.exir.to_edge_transform_and_lower(
             ep,
@@ -82,7 +90,11 @@ def test_dequantize_affine_b4w_linear(self):
 
     def test_dequantize_affine_c4w_embedding(self):
         model, example_inputs = self._get_test_model()
-        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)), lambda m, fqn: isinstance(m, torch.nn.Embedding))
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
         ep = torch.export.export(model, example_inputs)
         delegated_program = executorch.exir.to_edge_transform_and_lower(
             ep,
@@ -99,7 +111,9 @@ def test_dequantize_affine_c4w_embedding(self):
 
     def test_dequantize_affine_c4w_linear(self):
         model, example_inputs = self._get_test_model()
-        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)))
+        quantize_(
+            model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0))
+        )
         ep = torch.export.export(model, example_inputs)
         delegated_program = executorch.exir.to_edge_transform_and_lower(
             ep,
@@ -113,11 +127,18 @@ def test_dequantize_affine_c4w_linear(self):
                 ], f"Got unexpected node target after delegation: {node.target.__name__}"
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
-
+
     def test_dequantize_affine_c8w_embedding_b4w_linear(self):
         model, example_inputs = self._get_test_model()
-        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)), lambda m, fqn: isinstance(m, torch.nn.Embedding))
-        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
         ep = torch.export.export(model, example_inputs)
         delegated_program = executorch.exir.to_edge_transform_and_lower(
             ep,
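Each test in this diff follows the same flow: quantize the eager model with torchao, export it with torch.export, lower the CoreML-supported subgraphs, and emit an ExecuTorch program. The condensed sketch below is not part of the diff; it uses only APIs that appear above, except that it constructs a default CoreMLPartitioner() instead of the explicit compile specs built by the _coreml_partitioner helper (an assumption for brevity), and it relies on torchao's default filter so only the Linear layer is quantized, as in test_dequantize_affine_b4w_linear.

# Sketch (assumed standalone equivalent of the flow exercised by the tests above).
import torch
import executorch.exir
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.quantization import IntxWeightOnlyConfig, PerGroup, quantize_

# Same toy model as _get_test_model: embedding -> linear -> relu.
model = torch.nn.Sequential(
    torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
)
example_inputs = (torch.LongTensor([0]),)

# 4-bit weight-only quantization with group size 32; default filter targets Linear.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)

# Export, delegate supported subgraphs to CoreML, and emit an ExecuTorch program.
# Assumption: default CoreMLPartitioner(); the tests pass explicit compile specs.
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
    ep,
    partitioner=[CoreMLPartitioner()],
)
et_prog = delegated_program.to_executorch()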