2121import torch
2222import torch .nn as nn
2323
24+ from tico .quantization .algorithm .fpi_gptq .util import iterate_GPTQ
2425
2526def quantize (x , scale , zero , maxq ):
2627 if maxq < 0 :
@@ -41,11 +42,12 @@ def configure(
4142 bits ,
4243 perchannel = False ,
4344 sym = True ,
44- mse = False ,
45+ mse = None ,
4546 norm = 2.4 ,
4647 grid = 100 ,
4748 maxshrink = 0.8 ,
4849 trits = False ,
50+ sensitivity = None
4951 ):
5052 self .maxq = torch .tensor (2 ** bits - 1 )
5153 self .perchannel = perchannel
@@ -54,6 +56,7 @@ def configure(
5456 self .norm = norm
5557 self .grid = grid
5658 self .maxshrink = maxshrink
59+ self .sensitivity = sensitivity
5760 if trits :
5861 self .maxq = torch .tensor (- 1 )
5962
@@ -99,7 +102,10 @@ def find_params(self, x, weight=False):
99102 else :
100103 self .zero = torch .round (- xmin / self .scale )
101104
102- if self .mse :
105+ if self .mse is not None and self .mse != "smse_for_gptq" :
106+ if self .mse == "smse" :
107+ self .maxshrink = 0.5
108+
103109 best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
104110 for i in range (int (self .maxshrink * self .grid )):
105111 p = 1 - i / self .grid
@@ -110,13 +116,19 @@ def find_params(self, x, weight=False):
110116 q = quantize (x , scale1 .unsqueeze (1 ), zero1 .unsqueeze (1 ), self .maxq )
111117 q -= x
112118 q .abs_ ()
113- q .pow_ (self .norm )
119+ if self .mse == "smse" :
120+ q = (q ** 2 ) * self .sensitivity .to (
121+ q .device
122+ ) # sensitivity weighted `mse`
123+ else :
124+ q .pow_ (self .norm )
114125 err = torch .sum (q , 1 )
115126 tmp = err < best
116127 if torch .any (tmp ):
117128 best [tmp ] = err [tmp ]
118129 self .scale [tmp ] = scale1 [tmp ]
119130 self .zero [tmp ] = zero1 [tmp ]
131+
120132 if not self .perchannel :
121133 if weight :
122134 tmp = shape [0 ]
@@ -140,7 +152,84 @@ def find_params(self, x, weight=False):
140152 if len (shape ) == 2 :
141153 self .scale = self .scale .unsqueeze (0 )
142154 self .zero = self .zero .unsqueeze (0 )
155+
def update(self, x, Hinv, perm):
    """Refine quantization parameters (scale/zero) with a GPTQ-aware grid search.

    Re-runs the min/max shrink search, but scores each candidate
    (scale, zero) pair by the error produced by the actual GPTQ
    iteration (via ``iterate_GPTQ``) instead of plain rounding error.
    Active only when ``self.mse`` is ``"mse_for_gptq"`` or
    ``"smse_for_gptq"``; otherwise returns immediately.

    Args:
        x: Weight tensor to quantize. Flattened per-channel (or fully,
           when ``self.perchannel`` is False) before the search.
        Hinv: Inverse-Hessian factor used by the GPTQ update; its
           diagonal weights the plain-MSE error term.
        perm: Optional column permutation (from activation ordering)
           applied to ``self.sensitivity`` so it lines up with ``x``.

    Side effects:
        Overwrites ``self.scale`` and ``self.zero`` with the best
        candidates found, reshaped to broadcast against ``x``'s original
        shape; also sets ``self.maxshrink`` to 0.5.
    """
    # Only the GPTQ-aware MSE modes are handled here (None falls through too).
    if self.mse not in ("smse_for_gptq", "mse_for_gptq"):
        return

    shape = x.shape
    if self.perchannel:
        x = x.flatten(1)
    else:
        x = x.flatten().unsqueeze(0)

    dev = x.device
    # Clamp the ranges so 0 is always representable.
    tmp = torch.zeros(x.shape[0], device=dev)
    xmin = torch.minimum(x.min(1)[0], tmp)
    xmax = torch.maximum(x.max(1)[0], tmp)

    if self.sym:
        # Symmetric range: mirror the larger magnitude on both sides.
        xmax = torch.maximum(torch.abs(xmin), xmax)
        tmp = xmin < 0
        if torch.any(tmp):
            xmin[tmp] = -xmax[tmp]
    # Degenerate all-zero rows get a dummy [-1, 1] range to avoid scale == 0.
    tmp = (xmin == 0) & (xmax == 0)
    xmin[tmp] = -1
    xmax[tmp] = +1

    if self.maxq < 0:
        # Trits mode: store the raw range instead of scale/zero.
        self.scale = xmax
        self.zero = xmin
    else:
        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)  # type: ignore[arg-type]
        else:
            self.zero = torch.round(-xmin / self.scale)

    # GPTQ-aware search always shrinks at most half of the grid.
    self.maxshrink = 0.5
    sensitivity = None
    if self.sensitivity is not None:
        sensitivity = self.sensitivity.to(Hinv.dtype).to(dev)
        if perm is not None:
            # Reorder sensitivity columns to match the permuted weights.
            sensitivity = sensitivity[:, perm.to(dev)]

    num_of_iters = 15
    best = torch.full([x.shape[0]], float("inf"), device=dev)
    for i in range(int(self.maxshrink * self.grid)):
        p = 1 - i / self.grid
        xmin1 = p * xmin
        xmax1 = p * xmax
        scale1 = (xmax1 - xmin1) / self.maxq
        zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
        # Simulate the GPTQ column-wise update under this candidate grid.
        q, pre_q = iterate_GPTQ(
            scale1.unsqueeze(1),
            zero1.unsqueeze(1),
            self.maxq,
            x,
            Hinv,
            max_num_of_iters=num_of_iters,
        )
        if sensitivity is not None:
            assert self.mse == "smse_for_gptq"
            # Sensitivity-weighted squared error.
            err = ((q - pre_q) ** 2) * sensitivity.to(q.device)
        else:
            assert self.mse == "mse_for_gptq"
            # Hessian-diagonal-weighted squared error.
            err = ((q - pre_q) / torch.diag(Hinv)) ** 2
        err = torch.sum(err, 1)
        tmp = err < best
        if torch.any(tmp):
            best[tmp] = err[tmp]
            self.scale[tmp] = scale1[tmp]
            self.zero[tmp] = zero1[tmp]

    # Reshape for broadcasting against the original weight shape.
    shape = [-1] + [1] * (len(shape) - 1)
    self.scale = self.scale.reshape(shape)
    self.zero = self.zero.reshape(shape)
232+
144233 def quantize (self , x ):
145234 if self .ready ():
146235 return quantize (x , self .scale , self .zero , self .maxq )
0 commit comments