
Commit e74c1ca

feat: Add support for 4-bit quantization
1 parent 912e11e commit e74c1ca

8 files changed: +107 -40 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@ data/
 runs/
 runs_opt/
 backup/
+batchtest/
 *.obj
 *.dll
 *.exp
@@ -16,6 +17,8 @@ backup/
 *.lst
 *.bkp
 *.pdf
+*.log
+*.elog
 # python cache
 __pycache__/
 venv/

BitNetMCU.py

Lines changed: 32 additions & 29 deletions
@@ -14,20 +14,21 @@ class FCMNIST(nn.Module):
     @cpldcpu 2024-March-24

     """
-    def __init__(self,network_width1=64,network_width2=64,network_width3=64,QuantType='Binary',WScale='PerTensor',NormType='RMS'):
+    def __init__(self,network_width1=64,network_width2=64,network_width3=64,QuantType='Binary',WScale='PerTensor',NormType='RMS', quantscale=0.25):
         super(FCMNIST, self).__init__()

         self.network_width1 = network_width1
         self.network_width2 = network_width2
         self.network_width3 = network_width3
+        self.quantscale = quantscale

-        self.fc1 = BitLinear(1* 1 *16 *16, network_width1,QuantType=QuantType,NormType=NormType, WScale=WScale)
-        self.fc2 = BitLinear(network_width1, network_width2,QuantType=QuantType,NormType=NormType, WScale=WScale)
+        self.fc1 = BitLinear(1* 1 *16 *16, network_width1,QuantType=QuantType,NormType=NormType, WScale=WScale, quantscale=quantscale)
+        self.fc2 = BitLinear(network_width1, network_width2,QuantType=QuantType,NormType=NormType, WScale=WScale, quantscale=quantscale)
         if network_width3>0:
-            self.fc3 = BitLinear(network_width2, network_width3,QuantType=QuantType,NormType=NormType, WScale=WScale)
-            self.fcl = BitLinear(network_width3, 10,QuantType=QuantType,NormType=NormType, WScale=WScale)
+            self.fc3 = BitLinear(network_width2, network_width3,QuantType=QuantType,NormType=NormType, WScale=WScale, quantscale=quantscale)
+            self.fcl = BitLinear(network_width3, 10,QuantType=QuantType,NormType=NormType, WScale=WScale, quantscale=quantscale)
         else:
-            self.fcl = BitLinear(network_width2, 10,QuantType=QuantType,NormType=NormType, WScale=WScale)
+            self.fcl = BitLinear(network_width2, 10,QuantType=QuantType,NormType=NormType, WScale=WScale, quantscale=quantscale)

         # self.dropout = nn.Dropout(0.10)

@@ -64,18 +65,23 @@ class BitLinear(nn.Linear):
        - PerTensor : The weight scaling is calculated per Tensor
        - PerOutput : The weight scaling is calculated per Output

+    quantscale
+        - scalar : The scale factor for the weight quantization; the default of 0.25
+          biases the stddev of the weights toward 25% of the maximum scale
+
    Implementation based on:
    https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf

    This is not optimized for speed or efficiency...

    @cpldcpu 2024-March-24
    """
-    def __init__(self, in_features, out_features, bias=False, QuantType='Binary', WScale='PerTensor', NormType='RMS'):
+    def __init__(self, in_features, out_features, bias=False, QuantType='Binary', WScale='PerTensor', NormType='RMS', quantscale=0.25):
        super(BitLinear, self).__init__(in_features, out_features, bias=False)
        self.QuantType = QuantType
        self.NormType = NormType
        self.WScale = WScale
+       self.quantscale = quantscale

        # flat init - does not help so keep default
        # fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
@@ -132,12 +138,6 @@ def weight_quant(self, w):
         if self.QuantType == 'Ternary': # 1.58bits
             scale = 1.0 / mag
             u = (w * scale).round().clamp_(-1, 1) / scale
-        elif self.QuantType == 'Ternary06': # 1 bit
-            scale = 0.6 / mag
-            u = (w * scale).round().clamp_(-1, 1) / scale
-        elif self.QuantType == 'Ternary4': # 1 bit
-            scale = 4 / mag
-            u = (w * scale).round().clamp_(-1, 1) / scale
         elif self.QuantType == 'Binary': # 1 bit
             scale = mag
             e = w.mean()
@@ -146,27 +146,24 @@ def weight_quant(self, w):
             scale = mag
             # e = w.mean()
             u = w.sign() * scale
-        elif self.QuantType == 'BinarySymHS': # 1 bit
-            scale = mag
-            u = w.sign() * scale * 0.5
-        elif self.QuantType == 'BinarySymDS': # 1 bit
-            scale = mag
-            u = w.sign() * scale * 2.0
         elif self.QuantType == '2bitsym':
             scale = 1.0 / mag # 2 worst, 1 better, 1.5 almost as bad as 2
             u = ((w * scale - 0.5).round().clamp_(-2, 1) + 0.5) / scale
+        elif self.QuantType == '4bit': # 4 bit in one-complement encoding for inference with multiplication
+            scale = self.quantscale * 8.0 / mag # 2.0 for tensor, 6.5 for output
+            u = ((w * scale).round().clamp_(-8, 7)) / scale
         elif self.QuantType == '4bitsym':
-            scale = 2.0 / mag # 2.0 for tensor, 6.5 for output
+            scale = self.quantscale * 8.0 / mag # 2.0 for tensor, 6.5 for output
             u = ((w * scale - 0.5).round().clamp_(-8, 7) + 0.5) / scale
-        elif self.QuantType == 'FP130': # encoding (F1.3.0) : S * ( 2^E3 + 1) -> min 2^0 = 1, max 2^7 = 127
-            scale = 16.0 / mag
+        elif self.QuantType == 'FP130': # encoding (F1.3.0) : S * ( 2^E3 + 1) -> min 2^0 = 1, max 2^7 = 128
+            scale = 128.0 * self.quantscale / mag
             e = ((w * scale).abs()).log2().floor().clamp_(0, 7)
             u = w.sign()*(e.exp2()) / scale
         elif self.QuantType == '5bitsym':
-            scale = 4.0 / mag # 4.0 for tensor, 13 for output
+            scale = 16.0 * self.quantscale / mag # 4.0 for tensor, 13 for output
             u = ((w * scale - 0.5).round().clamp_(-16, 15) + 0.5) / scale
         elif self.QuantType == '8bit': # -128 to 127
-            scale = 32.0 / mag
+            scale = 128.0 * self.quantscale / mag
             u = (w * scale).round().clamp_(-128, 127) / scale
         else:
             raise AssertionError(f"Invalid QuantType: {self.QuantType}. Expected one of: 'Binary', 'BinaryBalanced', '2bitsym', '4bitsym', '8bit'")
@@ -197,13 +194,15 @@ class QuantizedModel:
    This class represents a quantized model. It provides functionality to quantize a given model.
    """

-    def __init__(self, model = None, force_quantization = None):
+    def __init__(self, model = None, force_quantization = None, quantscale=0.25):
         self.quantized_model=None
         self.total_bits=0
         self.force_quantization = force_quantization
+        self.quantscale = quantscale

         if model is not None:
             self.quantized_model, _ = self.quantize(model)
+            self.quantscale = model.quantscale

    def totalbits(self):
        """
@@ -263,21 +262,25 @@ def quantize(self,model):
                scale = 1.0 / mag # 2 worst, 1 better, 1.5 almost as bad as 2
                u = ((w * scale - 0.5).round().clamp_(-2, 1) + 0.5)
                bpw = 2
+            elif QuantType == '4bit': # 4 bit in one-complement encoding for inference with multiplication
+                scale = 8.0 * self.quantscale / mag # 2.0 for tensor, 6.5 for output
+                u = ((w * scale).round().clamp_(-8, 7))
+                bpw = 4
            elif QuantType == '4bitsym':
-                scale = 2.0 / mag # 2.0 for tensor, 6.5 for output
+                scale = 8.0 * self.quantscale / mag # 2.0 for tensor, 6.5 for output
                u = ((w * scale - 0.5).round().clamp_(-8, 7) + 0.5)
                bpw = 4
            elif QuantType == 'FP130':
-                scale = 16.0 / mag
+                scale = 128.0 * self.quantscale / mag
                e = ((w * scale ).abs()).log2().floor().clamp_(0, 7)
                u = w.sign()*(e.exp2() )
                bpw = 4
            elif QuantType == '5bitsym':
-                scale = 4.0 / mag # 4.0 for tensor, 14 for output
+                scale = 16.0 * self.quantscale / mag # 4.0 for tensor, 14 for output
                u = ((w * scale - 0.5).round().clamp_(-16, 15) + 0.5)
                bpw = 5
            elif QuantType == '8bit':
-                scale = 32.0 / mag
+                scale = 128.0 * self.quantscale / mag
                u = (w * scale).round().clamp_(-128, 127)
                bpw = 8
            elif QuantType == 'None':
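
Taken together, the new `quantscale` parameter folds the previously hard-coded per-tensor factors into one knob: with the default of 0.25, `quantscale * 8.0 / mag` reproduces the old `2.0 / mag` for '4bitsym', and `128.0 * 0.25 = 32.0` matches the old '8bit' factor. A minimal PyTorch sketch of the new '4bit' branch (the per-tensor `mag` computation and the straight-through pass are assumptions based on the surrounding BitLinear code, which is not part of this diff):

```
import torch

def quant_4bit(w: torch.Tensor, quantscale: float = 0.25) -> torch.Tensor:
    # Per-tensor magnitude; assumed here to be the mean absolute weight
    # (computed as `mag` outside the hunks shown above).
    mag = w.abs().mean().clamp(min=1e-5)
    scale = quantscale * 8.0 / mag                   # 0.25 * 8.0 / mag == the old 2.0 / mag
    u = (w * scale).round().clamp_(-8, 7) / scale    # integer codes -8..7, rescaled back
    # Straight-through estimator so rounding is transparent to gradients
    # (assumed, mirroring the usual BitNet-style BitLinear forward).
    return w + (u - w).detach()

w = torch.randn(64, 256) * 0.05
wq = quant_4bit(w)
codes = (wq * (0.25 * 8.0 / w.abs().mean())).round()
print(int(codes.min()), int(codes.max()))            # stays within [-8, 7]
```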

BitNetMCU_inference.c

Lines changed: 26 additions & 0 deletions
@@ -111,6 +111,8 @@ void processfclayer( int8_t *activations, const uint32_t *weights, int32_t bits
                 weightChunk <<= 2;
             }
         }
+    // Multiplier-less inference for RB32EC
+    #if defined(__riscv) && !defined(__riscv_mul)
     } else if (bits_per_weight == 4 ) {
         for (uint32_t k = 0; k < n_input; k+=8) {
             uint32_t weightChunk = *weightidx++;
@@ -126,6 +128,30 @@ void processfclayer( int8_t *activations, const uint32_t *weights, int32_t bits
                 weightChunk <<= 4;
             }
         }
+    #else
+    } else if (bits_per_weight == 4 ) {
+        for (uint32_t k = 0; k < n_input; k+=8) {
+            uint32_t weightChunk = *weightidx++;
+            for (uint32_t j = 0; j < 8; j++) {
+                int32_t in=*activations_idx++;
+                if (in != 0) { // Skip zero activations to speed up inference in layers after first layer
+                    int32_t tmpsum = (weightChunk & 0x80000000) ? -in : in; // one complements sign (bit set equals negative)
+                    sum += tmpsum * ((weightChunk>>(32-4))&7); // sign*in*1
+                }
+                weightChunk <<= 4;
+            }
+        }
+    #endif
+    } else if (bits_per_weight == 8 + 4 ) { // 4 bit twos-complement
+        for (uint32_t k = 0; k < n_input; k+=8) {
+            int32_t weightChunk = *weightidx++;
+            for (uint32_t j = 0; j < 8; j++) {
+                int32_t in=*activations_idx++;
+                int32_t weight = (weightChunk) >> (32-4); // extend sign, cut off lower bits
+                sum += in*weight;
+                weightChunk <<= 4;
+            }
+        }
    } else if (bits_per_weight == 16 + 4 ) { // 4 bit shift
        for (uint32_t k = 0; k < n_input; k+=8) {
            uint32_t weightChunk = *weightidx++;
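
The new `bits_per_weight == 8 + 4` path multiplies each activation by a sign-extended two's-complement nibble taken from the top of a 32-bit chunk. A small Python sketch that emulates that decode (chunk layout as implied by the C loop above; this is an illustration, not the exporter's packing routine):

```
def unpack_4bit_twos_complement(chunk: int) -> list:
    """Emulate the C decode: eight 4-bit two's-complement weights packed
    MSB-first in one 32-bit word, each sign-extended before the multiply."""
    weights = []
    for _ in range(8):
        nibble = (chunk >> 28) & 0xF                              # weightChunk >> (32 - 4)
        weights.append(nibble - 16 if nibble & 0x8 else nibble)   # sign-extend
        chunk = (chunk << 4) & 0xFFFFFFFF                         # weightChunk <<= 4
    return weights

print(unpack_4bit_twos_complement(0x8F700000)[:4])   # [-8, -1, 7, 0]
```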

docs/documentation.md

Lines changed: 18 additions & 1 deletion
@@ -538,7 +538,10 @@ By simplifying the model architecture and using a full-custom implementation, I

 While this project focused on MNIST inference as a test case, I plan to apply this approach to other applications in the future.

-# Addendum: FP1.3.0 Quantization
+# Addendum: Additional quantization schemes
+
+
+## FP1.3.0 Quantization

 <div align="center">
 <img src="first_layer_weights_fp130.png" width="60%">
@@ -550,6 +553,20 @@ While this project focused on MNIST inference as a test case, I plan to apply th

 TODO

+```
+1ee:  00170483  lb    s1,1(a4)
+1f2:  00035463  bgez  t1,1fa <processfclayer+0x4a>
+1f6:  409004b3  neg   s1,s1
+
+1fa:  01c35313  srli  t1,t1,0x1c
+1fe:  00737313  andi  t1,t1,7
+202:  006494b3  sll   s1,s1,t1
+
+206:  00879313  slli  t1,a5,0x8
+
+20a:  9626      add   a2,a2,s1
+```
+
 # References

 References and further reading:
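
The disassembly added to the FP1.3.0 addendum above shows the multiplier-less inner loop: load an activation byte, negate it when the weight's sign bit is set, shift it left by the 3-bit exponent, and accumulate. One step does roughly the following in Python terms (a sketch of the decode, assuming the exporter's nibble layout of sign in bit 3 and exponent in bits 0-2):

```
def fp130_mac_step(acc: int, activation: int, nibble: int) -> int:
    """One FP1.3.0 multiply-accumulate: weight = sign * 2**exponent,
    realised as a conditional negate plus a left shift (no multiplier)."""
    signed_act = -activation if nibble & 0x8 else activation   # bgez / neg
    exponent = nibble & 0x7                                     # andi t1,t1,7
    return acc + (signed_act << exponent)                       # sll + add

print(fp130_mac_step(0, 3, 0b1010))   # sign set, exp = 2 -> 0 + (-3 << 2) = -12
```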

exportquant.py

Lines changed: 10 additions & 3 deletions
@@ -8,6 +8,7 @@
 import matplotlib.pyplot as plt
 import argparse
 import yaml
+import seaborn as sns

 # Export quantized model from saved checkpoint
 # cpldcpu 2024-04-14
@@ -80,6 +81,9 @@ def export_to_hfile(quantized_model, filename, runname):
        elif quantization_type == '4bitsym':
            encoded_weights = ((weights < 0).astype(data_type) << 3) | (np.floor(np.abs(weights))).astype(data_type) # use bitwise operations to encode the weights
            QuantID = 4
+        elif quantization_type == '4bit':
+            encoded_weights = np.floor(weights).astype(int) & 15 # twos complement encoding
+            QuantID = 8 + 4
        elif quantization_type == 'FP130': # FP1.3.0 encoding (sign * 2^exp)
            encoded_weights = ((weights < 0).astype(data_type) << 3) | (np.floor(np.log2(np.abs(weights)))).astype(data_type)
            QuantID = 16 + 4
@@ -213,12 +217,14 @@ def plot_weight_histograms(quantized_model):

    for layer_index, layer in enumerate(quantized_model.quantized_model):
        layer_weights = np.array(layer['quantized_weights'])
+        bpw = layer['bpw']

        flattened_weights = layer_weights.flatten()

        ax = fig.add_subplot(len(quantized_model.quantized_model), 1, layer_index + 1)

-        ax.hist(flattened_weights, bins='auto')
+        # ax.hist(flattened_weights, width=1, bins='auto')
+        sns.histplot(flattened_weights, bins=2**bpw, ax=ax, kde=True)
        ax.set_title(f'Layer {layer_index+1} Weight Distribution')

    plt.tight_layout()
@@ -266,7 +272,8 @@ def plot_weight_histograms(quantized_model):
        network_width3=hyperparameters["network_width3"],
        QuantType=hyperparameters["QuantType"],
        NormType=hyperparameters["NormType"],
-       WScale=hyperparameters["WScale"]
+       WScale=hyperparameters["WScale"],
+       quantscale=hyperparameters["quantscale"]
    ).to(device)

    print('Loading model...')
@@ -292,7 +299,7 @@ def plot_weight_histograms(quantized_model):

    print('Quantizing model...')
    # Quantize the model
-   quantized_model = QuantizedModel(model)
+   quantized_model = QuantizedModel(model, quantscale=hyperparameters["quantscale"])

    # Print statistics
    print_stats(quantized_model)
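
The new '4bit' export branch stores each signed code in [-8, 7] as a raw two's-complement nibble (`& 15`), which is exactly what the `QuantID = 8 + 4` path in BitNetMCU_inference.c sign-extends back during inference. A quick numpy illustration of the round trip (example values chosen arbitrarily):

```
import numpy as np

weights = np.array([-8.0, -1.0, 0.0, 7.0])                # quantized codes from the '4bit' scheme
encoded = np.floor(weights).astype(int) & 15               # as in export_to_hfile: two's-complement nibbles
print(encoded)                                             # [ 8 15  0  7]

decoded = np.where(encoded >= 8, encoded - 16, encoded)    # what the C sign-extension recovers
print(decoded)                                             # [-8 -1  0  7]
```
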
Binary file not shown.

training.py

Lines changed: 2 additions & 1 deletion
@@ -189,7 +189,8 @@ def train_model(model, device, hyperparameters, train_data, test_data):
        network_width3=hyperparameters["network_width3"],
        QuantType=hyperparameters["QuantType"],
        NormType=hyperparameters["NormType"],
-       WScale=hyperparameters["WScale"]
+       WScale=hyperparameters["WScale"],
+       quantscale=hyperparameters["quantscale"]
    ).to(device)

    print('training...')

trainingparameters.yaml

Lines changed: 16 additions & 6 deletions
@@ -1,17 +1,27 @@
-num_epochs: 60
-QuantType: '4bitsym' # 'Ternary', 'Binary', 'BinaryBalanced', '2bitsym', '4bitsym', '8bit', 'None', 'FP130'
+# Quantization settings
+QuantType: '4bitsym' # 'Ternary', 'Binary', 'BinaryBalanced', '2bitsym', '4bit', '4bitsym', '8bit', 'None', 'FP130'
 BPW : 4
 NormType: 'RMS' # 'RMS', 'Lin', 'BatchNorm'
-WScale: 'PerTensor' # 'PerTensor', 'PerOutput', 'PerOutputLog2'
+WScale: 'PerTensor' # 'PerTensor', 'PerOutput'
+quantscale: 0.25 # How to scale the stddev of each tensor relative to the max value
+
+# Learning parameters
 batch_size: 128
+num_epochs: 60
 scheduler: "Cosine" # "StepLR", "Cosine"
 learning_rate: 0.001
 lr_decay: 0.1 # lr_decay and step size are not used with cosine scheduler
 step_size: 10
-network_width1: 64
-network_width2: 64
-network_width3: 64
+
+# Data augmentation
 augmentation: True
 rotation1: 10 # rotation1 and rotation2 are used for data augmentation
 rotation2: 10
+
+# Model parameters
+network_width1: 64
+network_width2: 64
+network_width3: 64
+
+# name
 runtag: "opt_" # runtag is prefix for runname
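
Putting the pieces together, the new `quantscale` key travels from the YAML file into both training and export. A sketch of that flow (the `yaml.safe_load` call and the standalone usage are assumptions; the actual argument handling in training.py and exportquant.py is only partially shown in this diff):

```
import yaml
from BitNetMCU import FCMNIST, QuantizedModel

with open("trainingparameters.yaml") as f:
    hyperparameters = yaml.safe_load(f)

model = FCMNIST(
    network_width1=hyperparameters["network_width1"],
    network_width2=hyperparameters["network_width2"],
    network_width3=hyperparameters["network_width3"],
    QuantType=hyperparameters["QuantType"],
    NormType=hyperparameters["NormType"],
    WScale=hyperparameters["WScale"],
    quantscale=hyperparameters["quantscale"],   # new in this commit
)
# QuantizedModel also receives quantscale, as in exportquant.py above
quantized = QuantizedModel(model, quantscale=hyperparameters["quantscale"])
```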
