@@ -94,40 +94,42 @@ def quantize(  # noqa C901
9494        embedding_pattern  =  r"emb.(\d+),(\d+)" 
9595        linear_pattern  =  r"lin8da.(\d+),(\d+)" 
9696
97-         linear_matches  =  re .findall (linear_pattern , qmode )
98-         if  linear_matches :
97+         matches  =  re .findall (linear_pattern , qmode )
98+         if  matches :
9999            assert  (
100-                 len (linear_matches ) ==  1 
101-             ), f"Expected 1 match but got {len(linear_matches)}"
102-             bitwidth  =  int (linear_matches [0 ][0 ])
103-             groupsize  =  int (linear_matches [0 ][1 ])
100+                 len (matches ) ==  1 
101+             ), f"Expected 1 match for linear_pattern but got {len(matches)}"
102+             bitwidth  =  int (matches [0 ][0 ])
103+             groupsize  =  int (matches [0 ][1 ])
104104            from  torchao .experimental .quant_api  import  (
105105                Int8DynActIntxWeightLinearQuantizer ,
106106            )
107107
108-             model  =  Int8DynActIntxWeightLinearQuantizer (
109-                 device = "cpu" ,
110-                 precision = torch_dtype ,
111-                 groupsize = groupsize ,
112-                 bitwidth = bitwidth ,
113-                 has_weight_zeros = False ,
114-             ).quantize (model )
108+             with  torch .no_grad ():
109+                 model  =  Int8DynActIntxWeightLinearQuantizer (
110+                     device = "cpu" ,
111+                     precision = torch_dtype ,
112+                     groupsize = groupsize ,
113+                     bitwidth = bitwidth ,
114+                     has_weight_zeros = False ,
115+                 ).quantize (model )
115116
116-         embedding_matches  =  re .findall (embedding_pattern , qmode )
117-         if  embedding_matches :
117+         matches  =  re .findall (embedding_pattern , qmode )
118+         if  matches :
118119            assert  (
119-                 len (embedding_matches ) ==  1 
120-             ), f"Expected 1 match but got {len(embedding_matches)}"
121-             bitwidth  =  int (embedding_matches [0 ][0 ])
122-             groupsize  =  int (embedding_matches [0 ][1 ])
120+                 len (matches ) ==  1 
121+             ), f"Expected 1 match for embedding_pattern but got {len(matches)}"
122+             bitwidth  =  int (matches [0 ][0 ])
123+             groupsize  =  int (matches [0 ][1 ])
123124            from  torchao .experimental .quant_api  import  IntxWeightEmbeddingQuantizer 
124125
125-             model  =  IntxWeightEmbeddingQuantizer (
126-                 device = "cpu" ,
127-                 precision = torch_dtype ,
128-                 bitwidth = bitwidth ,
129-                 groupsize = groupsize ,
130-             ).quantize (model )
126+             with  torch .no_grad ():
127+                 model  =  IntxWeightEmbeddingQuantizer (
128+                     device = "cpu" ,
129+                     precision = torch_dtype ,
130+                     bitwidth = bitwidth ,
131+                     groupsize = groupsize ,
132+                 ).quantize (model )
131133
132134        if  verbose :
133135            print ("quantized model:" , model )
0 commit comments