import torch
import torch.nn as nn

from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ

2526def quantize (x , scale , zero , maxq ):
2627 if maxq < 0 :
@@ -101,7 +102,7 @@ def find_params(self, x, weight=False):
101102 else :
102103 self .zero = torch .round (- xmin / self .scale )
103104
104- if self .mse is not None :
105+ if self .mse is not None and self . mse != "smse_for_gptq" and self . mse != "mse_for_gptq" :
105106 best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
106107 for i in range (int (self .maxshrink * self .grid )):
107108 p = 1 - i / self .grid
@@ -112,12 +113,10 @@ def find_params(self, x, weight=False):
112113 q = quantize (x , scale1 .unsqueeze (1 ), zero1 .unsqueeze (1 ), self .maxq )
113114 q -= x
114115 q .abs_ ()
115- if self .mse == "smse" : # senstitivity weighted mse
116- # in case senstitivity is a second order derivatives of some global loss
117- # (q**2) * self.sensitivity is just a global loss change due to quantization.
116+ if self .mse == "smse" :
118117 q = (q ** 2 ) * self .sensitivity .to (
119118 q .device
120- ) # estimate global target change
119+ ) # sensitivity weighted `mse`
121120 else :
122121 assert self .mse == "mse"
123122 q .pow_ (self .norm )
@@ -127,6 +126,7 @@ def find_params(self, x, weight=False):
127126 best [tmp ] = err [tmp ]
128127 self .scale [tmp ] = scale1 [tmp ]
129128 self .zero [tmp ] = zero1 [tmp ]
129+
130130 if not self .perchannel :
131131 if weight :
132132 tmp = shape [0 ]
@@ -151,6 +151,82 @@ def find_params(self, x, weight=False):
151151 self .scale = self .scale .unsqueeze (0 )
152152 self .zero = self .zero .unsqueeze (0 )
153153
154+ def update (self , x , Hinv , perm ):
155+ if self .mse is None or (
156+ self .mse != "smse_for_gptq" and self .mse != "mse_for_gptq"
157+ ):
158+ return
159+
160+ shape = x .shape
161+ if self .perchannel :
162+ x = x .flatten (1 )
163+ else :
164+ x = x .flatten ().unsqueeze (0 )
165+
166+ dev = x .device
167+ tmp = torch .zeros (x .shape [0 ], device = dev )
168+ xmin = torch .minimum (x .min (1 )[0 ], tmp )
169+ xmax = torch .maximum (x .max (1 )[0 ], tmp )
170+
171+ if self .sym :
172+ xmax = torch .maximum (torch .abs (xmin ), xmax )
173+ tmp = xmin < 0
174+ if torch .any (tmp ):
175+ xmin [tmp ] = - xmax [tmp ]
176+ tmp = (xmin == 0 ) & (xmax == 0 )
177+ xmin [tmp ] = - 1
178+ xmax [tmp ] = + 1
179+ if self .maxq < 0 :
180+ self .scale = xmax
181+ self .zero = xmin
182+ else :
183+ self .scale = (xmax - xmin ) / self .maxq
184+ if self .sym :
185+ self .zero = torch .full_like (self .scale , (self .maxq + 1 ) / 2 ) # type: ignore[arg-type]
186+ else :
187+ self .zero = torch .round (- xmin / self .scale )
188+
189+ sensitivity = None
190+ if self .sensitivity is not None :
191+ sensitivity = self .sensitivity .to (Hinv .dtype ).to (dev )
192+ if perm is not None :
193+ sensitivity = sensitivity [:, perm .to (dev )]
194+
195+ num_of_iters = 15
196+ best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
197+ for i in range (int (self .maxshrink * self .grid )):
198+ p = 1 - i / self .grid
199+ xmin1 = p * xmin
200+ xmax1 = p * xmax
201+ scale1 = (xmax1 - xmin1 ) / self .maxq
202+ zero1 = torch .round (- xmin1 / scale1 ) if not self .sym else self .zero
203+ q , pre_q = iterate_GPTQ (
204+ scale1 .unsqueeze (1 ),
205+ zero1 .unsqueeze (1 ),
206+ self .maxq ,
207+ x ,
208+ Hinv ,
209+ max_num_of_iters = num_of_iters ,
210+ )
211+ if sensitivity is not None :
212+ assert self .mse == "smse_for_gptq"
213+ err = ((q - pre_q ) ** 2 ) * sensitivity .to (q .device )
214+ else :
215+ assert self .mse == "mse_for_gptq"
216+ # err = torch.abs((q - pre_q)).pow_(self.norm)
217+ err = ((q - pre_q ) / torch .diag (Hinv )) ** 2
218+ err = err
219+ err = torch .sum (err , 1 )
220+ tmp = err < best
221+ if torch .any (tmp ):
222+ best [tmp ] = err [tmp ]
223+ self .scale [tmp ] = scale1 [tmp ]
224+ self .zero [tmp ] = zero1 [tmp ]
225+
226+ shape = [- 1 ] + [1 ] * (len (shape ) - 1 )
227+ self .scale = self .scale .reshape (shape )
228+ self .zero = self .zero .reshape (shape )
229+
    def quantize(self, x):
        """Fake-quantize ``x`` using the calibrated ``scale``/``zero``/``maxq``.

        Delegates to the module-level ``quantize`` helper once the quantizer
        has been calibrated (``self.ready()``).

        NOTE(review): the branch for ``not self.ready()`` lies past the visible
        span — presumably it returns ``x`` unchanged; confirm in the full file.
        """
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)