1717from executorch .extension .llm .export .builder import DType
1818
1919from sentencepiece import SentencePieceProcessor
20- from torch .nn .modules import linear
2120
2221try :
2322 from fairseq2 .nn .embedding import (
@@ -72,9 +71,17 @@ def quantize( # noqa: C901
7271 # Add quantization mode options here: group size, bit width, etc.
7372 return WeightOnlyInt8QuantHandler (model ).quantized_model ()
7473 elif qmode .startswith ("torchao:" ):
75- import os
7674 import glob
77- libs = glob .glob (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*" )))
75+ import os
76+
77+ libs = glob .glob (
78+ os .path .abspath (
79+ os .path .join (
80+ os .path .dirname (__file__ ),
81+ "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*" ,
82+ )
83+ )
84+ )
7885 assert len (libs ) == 1 , f"Expected 1 library but got { len (libs )} "
7986 logging .info (f"Loading custom ops library: { libs [0 ]} " )
8087 torch .ops .load_library (libs [0 ])
@@ -87,24 +94,32 @@ def quantize( # noqa: C901
8794
8895 linear_matches = re .findall (linear_pattern , qmode )
8996 if linear_matches :
90- assert len (linear_matches ) == 1 , f"Expected 1 match but got { len (linear_matches )} "
97+ assert (
98+ len (linear_matches ) == 1
99+ ), f"Expected 1 match but got { len (linear_matches )} "
91100 bitwidth = int (linear_matches [0 ][0 ])
92101 groupsize = int (linear_matches [0 ][1 ])
93- from torchao .experimental .quant_api import Int8DynActIntxWeightLinearQuantizer
102+ from torchao .experimental .quant_api import (
103+ Int8DynActIntxWeightLinearQuantizer ,
104+ )
105+
94106 model = Int8DynActIntxWeightLinearQuantizer (
95107 device = "cpu" ,
96108 precision = torch_dtype ,
97109 groupsize = groupsize ,
98110 bitwidth = bitwidth ,
99111 has_weight_zeros = False ,
100112 ).quantize (model )
101-
113+
102114 embedding_matches = re .findall (embedding_pattern , qmode )
103115 if embedding_matches :
104- assert len (embedding_matches ) == 1 , f"Expected 1 match but got { len (embedding_matches )} "
116+ assert (
117+ len (embedding_matches ) == 1
118+ ), f"Expected 1 match but got { len (embedding_matches )} "
105119 bitwidth = int (embedding_matches [0 ][0 ])
106120 groupsize = int (embedding_matches [0 ][1 ])
107121 from torchao .experimental .quant_api import IntxWeightEmbeddingQuantizer
122+
108123 model = IntxWeightEmbeddingQuantizer (
109124 device = "cpu" ,
110125 precision = torch_dtype ,
0 commit comments