2121import torch
2222import torch .nn as nn
2323
24+ from tico .quantization .algorithm .fpi_gptq .util import iterate_GPTQ
2425
2526def quantize (x , scale , zero , maxq ):
2627 if maxq < 0 :
@@ -41,11 +42,12 @@ def configure(
4142 bits ,
4243 perchannel = False ,
4344 sym = True ,
44- mse = False ,
45+ mse = None ,
4546 norm = 2.4 ,
4647 grid = 100 ,
4748 maxshrink = 0.8 ,
4849 trits = False ,
50+ sensitivity = None
4951 ):
5052 self .maxq = torch .tensor (2 ** bits - 1 )
5153 self .perchannel = perchannel
@@ -54,6 +56,7 @@ def configure(
5456 self .norm = norm
5557 self .grid = grid
5658 self .maxshrink = maxshrink
59+ self .sensitivity = sensitivity
5760 if trits :
5861 self .maxq = torch .tensor (- 1 )
5962
@@ -99,7 +102,10 @@ def find_params(self, x, weight=False):
99102 else :
100103 self .zero = torch .round (- xmin / self .scale )
101104
102- if self .mse :
105+ if self .mse is not None and self .mse != "smse_for_gptq" :
106+ if self .mse == "smse" :
107+ self .maxshrink = 0.5
108+
103109 best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
104110 for i in range (int (self .maxshrink * self .grid )):
105111 p = 1 - i / self .grid
@@ -110,13 +116,19 @@ def find_params(self, x, weight=False):
110116 q = quantize (x , scale1 .unsqueeze (1 ), zero1 .unsqueeze (1 ), self .maxq )
111117 q -= x
112118 q .abs_ ()
113- q .pow_ (self .norm )
119+ if self .mse == "smse" :
120+ q = (q ** 2 ) * self .sensitivity .to (
121+ q .device
122+ ) # sensitivity weighted `mse`
123+ else :
124+ q .pow_ (self .norm )
114125 err = torch .sum (q , 1 )
115126 tmp = err < best
116127 if torch .any (tmp ):
117128 best [tmp ] = err [tmp ]
118129 self .scale [tmp ] = scale1 [tmp ]
119130 self .zero [tmp ] = zero1 [tmp ]
131+
120132 if not self .perchannel :
121133 if weight :
122134 tmp = shape [0 ]
@@ -140,7 +152,81 @@ def find_params(self, x, weight=False):
140152 if len (shape ) == 2 :
141153 self .scale = self .scale .unsqueeze (0 )
142154 self .zero = self .zero .unsqueeze (0 )
155+
156+ def update (self , x , Hinv , perm ):
157+ if self .mse is None or self .mse != "smse_for_gptq" :
158+ return
159+
160+ shape = x .shape
161+ if self .perchannel :
162+ x = x .flatten (1 )
163+ else :
164+ x = x .flatten ().unsqueeze (0 )
165+
166+ dev = x .device
167+ tmp = torch .zeros (x .shape [0 ], device = dev )
168+ xmin = torch .minimum (x .min (1 )[0 ], tmp )
169+ xmax = torch .maximum (x .max (1 )[0 ], tmp )
143170
171+ if self .sym :
172+ xmax = torch .maximum (torch .abs (xmin ), xmax )
173+ tmp = xmin < 0
174+ if torch .any (tmp ):
175+ xmin [tmp ] = - xmax [tmp ]
176+ tmp = (xmin == 0 ) & (xmax == 0 )
177+ xmin [tmp ] = - 1
178+ xmax [tmp ] = + 1
179+ if self .maxq < 0 :
180+ self .scale = xmax
181+ self .zero = xmin
182+ else :
183+ self .scale = (xmax - xmin ) / self .maxq
184+ if self .sym :
185+ self .zero = torch .full_like (self .scale , (self .maxq + 1 ) / 2 ) # type: ignore[arg-type]
186+ else :
187+ self .zero = torch .round (- xmin / self .scale )
188+
189+ self .maxshrink = 0.5
190+ sensitivity = None
191+ if self .sensitivity is not None :
192+ sensitivity = self .sensitivity .to (Hinv .dtype ).to (dev )
193+ if perm is not None :
194+ sensitivity = sensitivity [:, perm .to (dev )]
195+
196+ self .norm = 2
197+ num_of_iters = 15
198+ best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
199+ for i in range (int (self .maxshrink * self .grid )):
200+ p = 1 - i / self .grid
201+ xmin1 = p * xmin
202+ xmax1 = p * xmax
203+ scale1 = (xmax1 - xmin1 ) / self .maxq
204+ zero1 = torch .round (- xmin1 / scale1 ) if not self .sym else self .zero
205+ q , pre_q = iterate_GPTQ (
206+ scale1 .unsqueeze (1 ),
207+ zero1 .unsqueeze (1 ),
208+ self .maxq ,
209+ x ,
210+ Hinv ,
211+ max_num_of_iters = num_of_iters ,
212+ )
213+ err = torch .abs ((q - x ))
214+ if sensitivity is not None :
215+ err = ((q - pre_q ) / torch .diag (Hinv ))** 2
216+ else :
217+ err .pow_ (self .norm )
218+ err = err
219+ err = torch .sum (err , 1 )
220+ tmp = err < best
221+ if torch .any (tmp ):
222+ best [tmp ] = err [tmp ]
223+ self .scale [tmp ] = scale1 [tmp ]
224+ self .zero [tmp ] = zero1 [tmp ]
225+
226+ shape = [- 1 ] + [1 ] * (len (shape ) - 1 )
227+ self .scale = self .scale .reshape (shape )
228+ self .zero = self .zero .reshape (shape )
229+
144230 def quantize (self , x ):
145231 if self .ready ():
146232 return quantize (x , self .scale , self .zero , self .maxq )
0 commit comments