Commit b780aad: qbnn example
1 parent 9b0118f

File tree: 4 files changed, +72 −11 lines

bayesian_torch/examples/main_bayesian_imagenet_bnn2qbnn.py

Lines changed: 15 additions & 8 deletions
@@ -16,8 +16,8 @@
 import bayesian_torch.models.bayesian.resnet_variational_large as resnet
 import numpy as np
 from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
-# import bayesian_torch.models.bayesian.quantized_resnet_variational_large as qresnet
-import bayesian_torch.models.bayesian.quantized_resnet_flipout_large as qresnet
+import bayesian_torch.models.bayesian.quantized_resnet_variational_large as qresnet
+# import bayesian_torch.models.bayesian.quantized_resnet_flipout_large as qresnet

 torch.cuda.is_available = lambda : False
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

@@ -262,9 +262,16 @@ def main():
     model.load_state_dict(checkpoint["state_dict"])
     model.module = model.module.cpu()

-    bnn_to_qbnn(model, fuse_conv_bn=False)  # only replaces linear and conv layers
+    mp = bayesian_torch.quantization.prepare(model)
+    evaluate(args, mp, val_loader)  # calibration
+    qmodel = bayesian_torch.quantization.convert(mp)
+    evaluate(args, qmodel, val_loader)

-    model = model.cpu()
+    # bnn_to_qbnn(model, fuse_conv_bn=False)  # only replaces linear and conv layers
+
+    # model = model.cpu()

     # save weights
     # save_checkpoint(

@@ -278,16 +285,16 @@ def main():
     #         args.save_dir,
     #         'quantized_bayesian_q{}_imagenet.pth'.format(args.arch)))

-    qmodel = torch.nn.DataParallel(qresnet.__dict__['q'+args.arch](bias=False))  # set bias=True to make qconv have a bias
-    qmodel.module.quant_then_dequant(qmodel, fuse_conv_bn=False)
+    # qmodel = torch.nn.DataParallel(qresnet.__dict__['q'+args.arch](bias=False))  # set bias=True to make qconv have a bias
+    # qmodel.module.quant_then_dequant(qmodel, fuse_conv_bn=False)

     # load weights
     # checkpoint_file = args.save_dir + "/quantized_bayesian_q{}_imagenet.pth".format(args.arch)
     # checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
     # qmodel.load_state_dict(checkpoint["state_dict"])

-    qmodel.load_state_dict(model.state_dict())
-    evaluate(args, qmodel, val_loader)
+    # qmodel.load_state_dict(model.state_dict())
+    # evaluate(args, qmodel, val_loader)

 if __name__ == "__main__":
     main()
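
The three added calls follow the standard post-training-quantization recipe: prepare attaches observers, a calibration pass populates them, and convert swaps in quantized layers. A minimal sketch of the assumed wiring (inferred from this commit, not a documented API contract; model, args, val_loader, and evaluate are the objects already defined in this example):

import bayesian_torch.quantization as quant

mp = quant.prepare(model)           # attach the per-layer QuantStubs added in linear_variational.py
evaluate(args, mp, val_loader)      # calibration: observers record tensor ranges
qmodel = quant.convert(mp)          # bnn_to_qbnn.py copies observed (scale, zero_point) pairs into quant_dict
evaluate(args, qmodel, val_loader)  # inference now takes the int8 path in quantize_linear_variational.py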

bayesian_torch/layers/variational_layers/linear_variational.py

Lines changed: 27 additions & 2 deletions
@@ -116,6 +116,15 @@ def __init__(self,
         self.register_buffer('eps_bias', None, persistent=False)

         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=HistogramObserver.with_args(dtype=torch.qint8), activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=HistogramObserver.with_args(dtype=torch.quint8), activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True

     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)

@@ -147,8 +156,10 @@ def forward(self, input, return_kl=True):
         if self.dnn_to_bnn_flag:
             return_kl = False
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
-        weight = self.mu_weight + \
-            (sigma_weight * self.eps_weight.data.normal_())
+        eps_weight = self.eps_weight.data.normal_()
+        tmp_result = sigma_weight * eps_weight
+        weight = self.mu_weight + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_weight, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)

@@ -162,6 +173,20 @@ def forward(self, input, return_kl=True):
                                   self.prior_bias_sigma)

         out = F.linear(input, weight, bias)
+
+        if self.quant_prepare:
+            # quint8 quantstubs (unsigned activations)
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstubs (signed weight-side tensors)
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight
+            mu_weight = self.qint_quant[1](self.mu_weight)  # weight
+            eps_weight = self.qint_quant[2](eps_weight)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
         if return_kl:
             if self.mu_bias is not None:
                 kl = kl_weight + kl_bias
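
Each QuantStub here only observes the tensor routed through it; the actual numbers come from its HistogramObserver. A standalone sketch (not from the repo) of how one observer turns the tensors it sees during calibration into the (scale, zero_point) pair later read out in bnn_to_qbnn.py:

import torch
from torch.quantization import HistogramObserver

obs = HistogramObserver(dtype=torch.qint8)
for _ in range(8):            # calibration forward passes
    obs(torch.randn(32, 64))  # record the value distribution
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())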

bayesian_torch/layers/variational_layers/quantize_linear_variational.py

Lines changed: 21 additions & 1 deletion
@@ -53,6 +53,7 @@ def __init__(self,
                          out_features)

         self.is_dequant = False
+        self.quant_dict = None

     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization

@@ -168,7 +169,26 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
         if self.dnn_to_bnn_flag:
             return_kl = False

-        if not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
+        if self.quant_dict is not None:
+            # Quantize a tensor drawn from a normal distribution; 99.7% of values lie
+            # within 3 standard deviations, so the original range is taken as 6.
+            eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8)
+            weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
+            weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'])
+            bias = None
+
+            ## DO NOT QUANTIZE BIAS!!!
+            if self.bias:
+                if self.quantized_sigma_bias is None:  # the case where bias comes from bn fusion
+                    bias = self.quantized_mu_bias
+                else:  # original case
+                    bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_())
+
+            if input.dtype != torch.quint8:  # check whether input has already been quantized
+                input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8)
+
+            out = torch.nn.quantized.functional.linear(input, weight, bias, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point'])  # input: quint8, weight: qint8, bias: fp32
+            out = out.dequantize()
+
+        elif not enable_int8_compute:  # Deprecated. Use this path for reducing model size only.
             if not self.is_dequant:
                 self.dequantize()
                 self.is_dequant = True
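
The new branch samples the weight entirely in int8: W = mu + sigma * eps, composed from quantized mul and add with the calibrated output scales. A standalone sketch with placeholder scales and zero points (made-up values, not from calibration):

import torch

mu    = torch.quantize_per_tensor(torch.randn(4, 4), 0.05, 0, torch.qint8)
sigma = torch.quantize_per_tensor(torch.rand(4, 4) * 0.1, 0.01, 0, torch.qint8)
eps   = torch.quantize_per_tensor(torch.randn(4, 4), 6 / 255, 0, torch.qint8)

tmp    = torch.ops.quantized.mul(sigma, eps, 0.02, 0)  # sigma * eps
weight = torch.ops.quantized.add(tmp, mu, 0.05, 0)     # mu + sigma * eps
print(weight.dequantize())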

bayesian_torch/models/bnn_to_qbnn.py

Lines changed: 9 additions & 0 deletions
@@ -101,6 +101,15 @@ def qbnn_linear_layer(d):
         out_features=d.out_features,
     )
     qbnn_layer.__dict__.update(d.__dict__)
+
+    if d.quant_prepare:
+        qbnn_layer.quant_dict = []
+        for qstub in d.qint_quant:
+            qbnn_layer.quant_dict.append({'scale': qstub.scale.item(), 'zero_point': qstub.zero_point.item()})
+        qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:]
+        for qstub in d.quint_quant:
+            qbnn_layer.quant_dict.append({'scale': qstub.scale.item(), 'zero_point': qstub.zero_point.item()})
+
     qbnn_layer.quantize()
     if d.dnn_to_bnn_flag:
         qbnn_layer.dnn_to_bnn_flag = True
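
The [2:] slice drops the first two qint entries (the sigma_weight and mu_weight observers), presumably because those tensors are quantized separately in qbnn_layer.quantize(); the five entries that remain line up with the indices the quantized forward pass reads. An illustration only, with made-up values:

quant_dict = [
    {'scale': 6 / 255, 'zero_point': 0},    # [0] eps_weight        (qint8)
    {'scale': 0.02,    'zero_point': 0},    # [1] sigma * eps (mul) (qint8)
    {'scale': 0.05,    'zero_point': 0},    # [2] mu + sigma*eps    (qint8)
    {'scale': 0.1,     'zero_point': 128},  # [3] input             (quint8)
    {'scale': 0.1,     'zero_point': 128},  # [4] output            (quint8)
]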
