1919from executorch .extension .llm .export .builder import DType
2020
2121from sentencepiece import SentencePieceProcessor
22- from torch .nn .modules import linear
2322
2423try :
2524 from fairseq2 .nn .embedding import (
@@ -74,9 +73,17 @@ def quantize( # noqa C901
7473 # Add quantization mode options here: group size, bit width, etc.
7574 return WeightOnlyInt8QuantHandler (model ).quantized_model ()
7675 elif qmode .startswith ("torchao:" ):
77- import os
7876 import glob
79- libs = glob .glob (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*" )))
77+ import os
78+
79+ libs = glob .glob (
80+ os .path .abspath (
81+ os .path .join (
82+ os .path .dirname (__file__ ),
83+ "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*" ,
84+ )
85+ )
86+ )
8087 assert len (libs ) == 1 , f"Expected 1 library but got { len (libs )} "
8188 logging .info (f"Loading custom ops library: { libs [0 ]} " )
8289 torch .ops .load_library (libs [0 ])
@@ -89,24 +96,32 @@ def quantize( # noqa C901
8996
9097 linear_matches = re .findall (linear_pattern , qmode )
9198 if linear_matches :
92- assert len (linear_matches ) == 1 , f"Expected 1 match but got { len (linear_matches )} "
99+ assert (
100+ len (linear_matches ) == 1
101+ ), f"Expected 1 match but got { len (linear_matches )} "
93102 bitwidth = int (linear_matches [0 ][0 ])
94103 groupsize = int (linear_matches [0 ][1 ])
95- from torchao .experimental .quant_api import Int8DynActIntxWeightLinearQuantizer
104+ from torchao .experimental .quant_api import (
105+ Int8DynActIntxWeightLinearQuantizer ,
106+ )
107+
96108 model = Int8DynActIntxWeightLinearQuantizer (
97109 device = "cpu" ,
98110 precision = torch_dtype ,
99111 groupsize = groupsize ,
100112 bitwidth = bitwidth ,
101113 has_weight_zeros = False ,
102114 ).quantize (model )
103-
115+
104116 embedding_matches = re .findall (embedding_pattern , qmode )
105117 if embedding_matches :
106- assert len (embedding_matches ) == 1 , f"Expected 1 match but got { len (embedding_matches )} "
118+ assert (
119+ len (embedding_matches ) == 1
120+ ), f"Expected 1 match but got { len (embedding_matches )} "
107121 bitwidth = int (embedding_matches [0 ][0 ])
108122 groupsize = int (embedding_matches [0 ][1 ])
109123 from torchao .experimental .quant_api import IntxWeightEmbeddingQuantizer
124+
110125 model = IntxWeightEmbeddingQuantizer (
111126 device = "cpu" ,
112127 precision = torch_dtype ,
0 commit comments