1919from  executorch .extension .llm .export .builder  import  DType 
2020
2121from  sentencepiece  import  SentencePieceProcessor 
22- from  torch .nn .modules  import  linear 
2322
2423try :
2524    from  fairseq2 .nn .embedding  import  (
@@ -74,9 +73,17 @@ def quantize(  # noqa C901
7473        # Add quantization mode options here: group size, bit width, etc. 
7574        return  WeightOnlyInt8QuantHandler (model ).quantized_model ()
7675    elif  qmode .startswith ("torchao:" ):
77-         import  os 
7876        import  glob 
79-         libs  =  glob .glob (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*" )))
77+         import  os 
78+ 
79+         libs  =  glob .glob (
80+             os .path .abspath (
81+                 os .path .join (
82+                     os .path .dirname (__file__ ),
83+                     "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*" ,
84+                 )
85+             )
86+         )
8087        assert  len (libs ) ==  1 , f"Expected 1 library but got { len (libs )}  
8188        logging .info (f"Loading custom ops library: { libs [0 ]}  )
8289        torch .ops .load_library (libs [0 ])
@@ -89,24 +96,32 @@ def quantize(  # noqa C901
8996
9097        linear_matches  =  re .findall (linear_pattern , qmode )
9198        if  linear_matches :
92-             assert  len (linear_matches ) ==  1 , f"Expected 1 match but got { len (linear_matches )}  
99+             assert  (
100+                 len (linear_matches ) ==  1 
101+             ), f"Expected 1 match but got { len (linear_matches )}  
93102            bitwidth  =  int (linear_matches [0 ][0 ])
94103            groupsize  =  int (linear_matches [0 ][1 ])
95-             from  torchao .experimental .quant_api  import  Int8DynActIntxWeightLinearQuantizer 
104+             from  torchao .experimental .quant_api  import  (
105+                 Int8DynActIntxWeightLinearQuantizer ,
106+             )
107+ 
96108            model  =  Int8DynActIntxWeightLinearQuantizer (
97109                device = "cpu" ,
98110                precision = torch_dtype ,
99111                groupsize = groupsize ,
100112                bitwidth = bitwidth ,
101113                has_weight_zeros = False ,
102114            ).quantize (model )
103-          
115+ 
104116        embedding_matches  =  re .findall (embedding_pattern , qmode )
105117        if  embedding_matches :
106-             assert  len (embedding_matches ) ==  1 , f"Expected 1 match but got { len (embedding_matches )}  
118+             assert  (
119+                 len (embedding_matches ) ==  1 
120+             ), f"Expected 1 match but got { len (embedding_matches )}  
107121            bitwidth  =  int (embedding_matches [0 ][0 ])
108122            groupsize  =  int (embedding_matches [0 ][1 ])
109123            from  torchao .experimental .quant_api  import  IntxWeightEmbeddingQuantizer 
124+ 
110125            model  =  IntxWeightEmbeddingQuantizer (
111126                device = "cpu" ,
112127                precision = torch_dtype ,
0 commit comments