@@ -94,40 +94,42 @@ def quantize(  # noqa C901
     embedding_pattern = r"emb.(\d+),(\d+)"
     linear_pattern = r"lin8da.(\d+),(\d+)"
 
-    linear_matches = re.findall(linear_pattern, qmode)
-    if linear_matches:
+    matches = re.findall(linear_pattern, qmode)
+    if matches:
         assert (
-            len(linear_matches) == 1
-        ), f"Expected 1 match but got {len(linear_matches)}"
-        bitwidth = int(linear_matches[0][0])
-        groupsize = int(linear_matches[0][1])
+            len(matches) == 1
+        ), f"Expected 1 match for linear_pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        groupsize = int(matches[0][1])
         from torchao.experimental.quant_api import (
             Int8DynActIntxWeightLinearQuantizer,
         )
 
-        model = Int8DynActIntxWeightLinearQuantizer(
-            device="cpu",
-            precision=torch_dtype,
-            groupsize=groupsize,
-            bitwidth=bitwidth,
-            has_weight_zeros=False,
-        ).quantize(model)
+        with torch.no_grad():
+            model = Int8DynActIntxWeightLinearQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                groupsize=groupsize,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
 
-    embedding_matches = re.findall(embedding_pattern, qmode)
-    if embedding_matches:
+    matches = re.findall(embedding_pattern, qmode)
+    if matches:
         assert (
-            len(embedding_matches) == 1
-        ), f"Expected 1 match but got {len(embedding_matches)}"
-        bitwidth = int(embedding_matches[0][0])
-        groupsize = int(embedding_matches[0][1])
+            len(matches) == 1
+        ), f"Expected 1 match for embedding_pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        groupsize = int(matches[0][1])
         from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
 
-        model = IntxWeightEmbeddingQuantizer(
-            device="cpu",
-            precision=torch_dtype,
-            bitwidth=bitwidth,
-            groupsize=groupsize,
-        ).quantize(model)
+        with torch.no_grad():
+            model = IntxWeightEmbeddingQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                bitwidth=bitwidth,
+                groupsize=groupsize,
+            ).quantize(model)
 
     if verbose:
         print("quantized model:", model)
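The other change in both hunks wraps the quantizer calls in torch.no_grad(). As a general PyTorch fact (not specific to these torchao quantizers), operations executed under no_grad do not record autograd history, so rewriting the model's weights cannot accidentally build a gradient graph. A tiny self-contained sketch:

import torch

w = torch.randn(4, 4, requires_grad=True)
with torch.no_grad():
    # Toy stand-in for a quantization step; the real change uses the
    # torchao quantizer classes shown in the diff above.
    q = (w / w.abs().max() * 127).round()

# Tensors produced under no_grad do not require gradients.
assert not q.requires_grad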