@@ -14,20 +14,21 @@ class FCMNIST(nn.Module):
     @cpldcpu 2024-March-24

     """
-    def __init__(self, network_width1=64, network_width2=64, network_width3=64, QuantType='Binary', WScale='PerTensor', NormType='RMS'):
+    def __init__(self, network_width1=64, network_width2=64, network_width3=64, QuantType='Binary', WScale='PerTensor', NormType='RMS', quantscale=0.25):
         super(FCMNIST, self).__init__()

         self.network_width1 = network_width1
         self.network_width2 = network_width2
         self.network_width3 = network_width3
+        self.quantscale = quantscale

-        self.fc1 = BitLinear(1 * 1 * 16 * 16, network_width1, QuantType=QuantType, NormType=NormType, WScale=WScale)
-        self.fc2 = BitLinear(network_width1, network_width2, QuantType=QuantType, NormType=NormType, WScale=WScale)
+        self.fc1 = BitLinear(1 * 1 * 16 * 16, network_width1, QuantType=QuantType, NormType=NormType, WScale=WScale, quantscale=quantscale)
+        self.fc2 = BitLinear(network_width1, network_width2, QuantType=QuantType, NormType=NormType, WScale=WScale, quantscale=quantscale)
         if network_width3 > 0:
-            self.fc3 = BitLinear(network_width2, network_width3, QuantType=QuantType, NormType=NormType, WScale=WScale)
-            self.fcl = BitLinear(network_width3, 10, QuantType=QuantType, NormType=NormType, WScale=WScale)
+            self.fc3 = BitLinear(network_width2, network_width3, QuantType=QuantType, NormType=NormType, WScale=WScale, quantscale=quantscale)
+            self.fcl = BitLinear(network_width3, 10, QuantType=QuantType, NormType=NormType, WScale=WScale, quantscale=quantscale)
         else:
-            self.fcl = BitLinear(network_width2, 10, QuantType=QuantType, NormType=NormType, WScale=WScale)
+            self.fcl = BitLinear(network_width2, 10, QuantType=QuantType, NormType=NormType, WScale=WScale, quantscale=quantscale)

         # self.dropout = nn.Dropout(0.10)

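For orientation, a minimal usage sketch of the widened constructor. The import path and the forward-pass input shape are assumptions; neither appears in this diff:

```python
import torch
# Assumed import; the module containing FCMNIST is not named in this diff.
from models import FCMNIST

model = FCMNIST(network_width1=64, network_width2=64, network_width3=64,
                QuantType='4bitsym', WScale='PerTensor', NormType='RMS',
                quantscale=0.25)
x = torch.randn(8, 256)   # fc1 expects 1*1*16*16 = 256 flattened inputs
logits = model(x)         # forward() is not shown here; assumed to return (8, 10) logits
```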
@@ -64,18 +65,23 @@ class BitLinear(nn.Linear):
         - PerTensor : The weight scaling is calculated per tensor
         - PerOutput : The weight scaling is calculated per output

+    quantscale
+        - scalar : Scale factor for the weight quantization. The default of 0.25
+          biases the standard deviation of the weights toward 25% of the maximum scale.
+
     Implementation based on:
     https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf

     This is not optimized for speed or efficiency...

     @cpldcpu 2024-March-24
     """
-    def __init__(self, in_features, out_features, bias=False, QuantType='Binary', WScale='PerTensor', NormType='RMS'):
+    def __init__(self, in_features, out_features, bias=False, QuantType='Binary', WScale='PerTensor', NormType='RMS', quantscale=0.25):
         super(BitLinear, self).__init__(in_features, out_features, bias=False)
         self.QuantType = QuantType
         self.NormType = NormType
         self.WScale = WScale
+        self.quantscale = quantscale

         # flat init - does not help so keep default
         # fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
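A standalone sketch (not part of the commit) of how quantscale folds into the '4bitsym' weight scale; with the default 0.25 it reproduces the previously hard-coded 2.0/mag. Here mag is assumed to be the clamped mean absolute weight, as in BitNet-style per-tensor scaling:

```python
import torch

w = torch.randn(64, 256)               # stand-in weight matrix
mag = w.abs().mean().clamp_(min=1e-5)  # assumed per-tensor magnitude
quantscale = 0.25
scale = quantscale * 8.0 / mag         # 0.25 * 8.0 = 2.0, matching the old constant
u = ((w * scale - 0.5).round().clamp_(-8, 7) + 0.5) / scale
print(u.unique().numel())              # at most 16 distinct levels, i.e. 4 bits
```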
@@ -132,12 +138,6 @@ def weight_quant(self, w):
         if self.QuantType == 'Ternary':  # 1.58 bits
             scale = 1.0 / mag
             u = (w * scale).round().clamp_(-1, 1) / scale
-        elif self.QuantType == 'Ternary06':  # 1 bit
-            scale = 0.6 / mag
-            u = (w * scale).round().clamp_(-1, 1) / scale
-        elif self.QuantType == 'Ternary4':  # 1 bit
-            scale = 4 / mag
-            u = (w * scale).round().clamp_(-1, 1) / scale
         elif self.QuantType == 'Binary':  # 1 bit
             scale = mag
             e = w.mean()
@@ -146,27 +146,24 @@ def weight_quant(self, w):
             scale = mag
             # e = w.mean()
             u = w.sign() * scale
-        elif self.QuantType == 'BinarySymHS':  # 1 bit
-            scale = mag
-            u = w.sign() * scale * 0.5
-        elif self.QuantType == 'BinarySymDS':  # 1 bit
-            scale = mag
-            u = w.sign() * scale * 2.0
         elif self.QuantType == '2bitsym':
             scale = 1.0 / mag  # 2 worst, 1 better, 1.5 almost as bad as 2
             u = ((w * scale - 0.5).round().clamp_(-2, 1) + 0.5) / scale
+        elif self.QuantType == '4bit':  # 4 bit in two's-complement encoding (clamp range -8..7), for inference with multiplication
+            scale = self.quantscale * 8.0 / mag  # 2.0 for tensor, 6.5 for output
+            u = ((w * scale).round().clamp_(-8, 7)) / scale
         elif self.QuantType == '4bitsym':
-            scale = 2.0 / mag  # 2.0 for tensor, 6.5 for output
+            scale = self.quantscale * 8.0 / mag  # 2.0 for tensor, 6.5 for output
             u = ((w * scale - 0.5).round().clamp_(-8, 7) + 0.5) / scale
-        elif self.QuantType == 'FP130':  # encoding (F1.3.0): S * (2^E3 + 1) -> min 2^0 = 1, max 2^7 = 127
-            scale = 16.0 / mag
+        elif self.QuantType == 'FP130':  # encoding (F1.3.0): S * 2^E3 -> min 2^0 = 1, max 2^7 = 128
+            scale = 128.0 * self.quantscale / mag
             e = ((w * scale).abs()).log2().floor().clamp_(0, 7)
             u = w.sign() * (e.exp2()) / scale
         elif self.QuantType == '5bitsym':
-            scale = 4.0 / mag  # 4.0 for tensor, 13 for output
+            scale = 16.0 * self.quantscale / mag  # 4.0 for tensor, 13 for output
             u = ((w * scale - 0.5).round().clamp_(-16, 15) + 0.5) / scale
         elif self.QuantType == '8bit':  # -128 to 127
-            scale = 32.0 / mag
+            scale = 128.0 * self.quantscale / mag
             u = (w * scale).round().clamp_(-128, 127) / scale
         else:
             raise AssertionError(f"Invalid QuantType: {self.QuantType}. Expected one of: 'Binary', 'BinaryBalanced', 'BinarySym', 'Ternary', '2bitsym', '4bit', '4bitsym', 'FP130', '5bitsym', '8bit'")
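To make the FP130 branch concrete, a standalone worked sketch under the same assumption about mag. Each quantized magnitude lands on a power-of-two grid, so inference needs only a sign and a shift:

```python
import torch

w = torch.tensor([0.9, -0.05, 0.3, -0.6])
mag = w.abs().mean().clamp_(min=1e-5)              # 0.4625 for this example
scale = 128.0 * 0.25 / mag                         # quantscale = 0.25 -> 32.0 / mag
e = (w * scale).abs().log2().floor().clamp_(0, 7)  # 3-bit exponent field
u = w.sign() * e.exp2() / scale                    # |u| = 2**e / scale
print(e.tolist(), u.tolist())                      # exponents in 0..7
```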
@@ -197,13 +194,15 @@ class QuantizedModel:
     This class represents a quantized model. It provides functionality to quantize a given model.
     """

-    def __init__(self, model=None, force_quantization=None):
+    def __init__(self, model=None, force_quantization=None, quantscale=0.25):
         self.quantized_model = None
         self.total_bits = 0
         self.force_quantization = force_quantization
+        self.quantscale = quantscale

         if model is not None:
+            self.quantscale = model.quantscale  # adopt the model's scale before quantizing
             self.quantized_model, _ = self.quantize(model)

     def totalbits(self):
         """
@@ -263,21 +262,25 @@ def quantize(self, model):
             scale = 1.0 / mag  # 2 worst, 1 better, 1.5 almost as bad as 2
             u = ((w * scale - 0.5).round().clamp_(-2, 1) + 0.5)
             bpw = 2
+        elif QuantType == '4bit':  # 4 bit in two's-complement encoding (clamp range -8..7), for inference with multiplication
+            scale = 8.0 * self.quantscale / mag  # 2.0 for tensor, 6.5 for output
+            u = ((w * scale).round().clamp_(-8, 7))
+            bpw = 4
         elif QuantType == '4bitsym':
-            scale = 2.0 / mag  # 2.0 for tensor, 6.5 for output
+            scale = 8.0 * self.quantscale / mag  # 2.0 for tensor, 6.5 for output
             u = ((w * scale - 0.5).round().clamp_(-8, 7) + 0.5)
             bpw = 4
         elif QuantType == 'FP130':
-            scale = 16.0 / mag
+            scale = 128.0 * self.quantscale / mag
             e = ((w * scale).abs()).log2().floor().clamp_(0, 7)
             u = w.sign() * (e.exp2())
             bpw = 4
         elif QuantType == '5bitsym':
-            scale = 4.0 / mag  # 4.0 for tensor, 14 for output
+            scale = 16.0 * self.quantscale / mag  # 4.0 for tensor, 14 for output
             u = ((w * scale - 0.5).round().clamp_(-16, 15) + 0.5)
             bpw = 5
         elif QuantType == '8bit':
-            scale = 32.0 / mag
+            scale = 128.0 * self.quantscale / mag
             u = (w * scale).round().clamp_(-128, 127)
             bpw = 8
         elif QuantType == 'None':
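As a quick sanity check on the bpw bookkeeping above, a sketch (pure Python, using the FCMNIST default widths) of the packed weight size for the 4-bit modes:

```python
# Weight storage for the default FCMNIST (widths 64/64/64) at 4 bits per weight.
layers = [(256, 64), (64, 64), (64, 64), (64, 10)]  # fc1, fc2, fc3, fcl
bpw = 4                                             # '4bit', '4bitsym' and 'FP130' all count 4
total_bits = sum(i * o * bpw for i, o in layers)
print(total_bits, total_bits // 8)                  # 100864 bits -> 12608 bytes
```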