Commit 9b5a9dc: quantization module prototype

1 parent b4ce3f5, commit 9b5a9dc
8 files changed: +272 −13 lines

bayesian_torch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from bayesian_torch import quantization as quantization

Lines changed: 2 additions & 1 deletion

@@ -1,2 +1,3 @@
 ## bayesian_torch.quantization.prepare
-## bayesian_torch.quantization.convert
+## bayesian_torch.quantization.convert
+from .quantize import *
Lines changed: 161 additions & 7 deletions

@@ -1,9 +1,163 @@
-"""
-define prepare and convert function
-"""
+# Copyright (C) 2021 Intel Labs
+#
+# BSD-3-Clause License
+#
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
+# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
+# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
+# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# Define prepare and convert functions
+#
 
-def prepare():
-    return
+import torch
+import torch.nn as nn
+from bayesian_torch.models.bayesian.resnet_variational_large import (
+    BasicBlock,
+    Bottleneck,
+    ResNet,
+)
+from typing import Any, List, Optional, Type, Union
+from torch import Tensor
+from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
+# import copy
 
-def convert():
-    return
+__all__ = [
+    "prepare",
+    "convert",
+]
+
+
+class QuantizableBasicBlock(BasicBlock):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        # FloatFunctional so the residual add + ReLU can run on quantized tensors
+        self.add_relu = torch.nn.quantized.FloatFunctional()
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out = self.add_relu.add_relu(out, identity)
+
+        return out
+
+
+class QuantizableBottleneck(Bottleneck):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.skip_add_relu = nn.quantized.FloatFunctional()
+        self.relu1 = nn.ReLU(inplace=False)
+        self.relu2 = nn.ReLU(inplace=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu2(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+        out = self.skip_add_relu.add_relu(out, identity)
+
+        return out
+
+
+class QuantizableResNet(ResNet):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.quant(x)
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        for layer in self.layer1:
+            x = layer(x)
+
+        for layer in self.layer2:
+            x = layer(x)
+
+        for layer in self.layer3:
+            x = layer(x)
+
+        for layer in self.layer4:
+            x = layer(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        # x = self.dequant(x)
+        return x
+
+
+def enable_prepare(m):
+    for name, value in list(m._modules.items()):
+        if m._modules[name]._modules:
+            enable_prepare(m._modules[name])
+        elif "Reparameterization" in m._modules[name].__class__.__name__ \
+                or "Flipout" in m._modules[name].__class__.__name__:
+            prepare = getattr(m._modules[name], "prepare", None)
+            if callable(prepare):
+                m._modules[name].prepare()
+                m._modules[name].dnn_to_bnn_flag = True
+
+
+def prepare(model):
+    """
+    1. construct a quantizable model
+    2. traverse the model to enable the prepare function in each Bayesian layer
+    3. run torch.quantization.prepare()
+    """
+    # NOTE: hardcoded to the ResNet-50 layout (Bottleneck, [3, 4, 6, 3])
+    qmodel = QuantizableResNet(QuantizableBottleneck, [3, 4, 6, 3])
+    qmodel.load_state_dict(model.state_dict())
+    qmodel.eval()
+    enable_prepare(qmodel)
+    qmodel.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+    qmodel = torch.quantization.prepare(qmodel)
+
+    return qmodel
+
+
+def convert(model):
+    qmodel = torch.quantization.convert(model)  # torch layers
+    bnn_to_qbnn(qmodel)  # Bayesian layers
+    return qmodel
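
For context on the traversal above: enable_prepare duck-types on a per-layer prepare() hook, keying only on class names that contain "Reparameterization" or "Flipout". A minimal sketch of that behavior on a stand-in module; ToyReparameterization is a hypothetical class, named so its suffix matches, and is not part of bayesian-torch:

import torch.nn as nn
from bayesian_torch.ao.quantization.quantize import enable_prepare

class ToyReparameterization(nn.Module):   # hypothetical leaf layer; the class-name
    def forward(self, x):                 # suffix is all enable_prepare() keys on
        return x

    def prepare(self):                    # the duck-typed hook enable_prepare() calls
        self.prepared = True

model = nn.Sequential(nn.Sequential(ToyReparameterization()), nn.ReLU())
enable_prepare(model)                     # recurses into the nested Sequential
print(model[0][0].prepared)               # True: the hook was found and invoked
print(model[0][0].dnn_to_bnn_flag)        # True: the layer's forward() skips returning KL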
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+# import torch
+# import bayesian_torch
+# from bayesian_torch.ao.quantization import prepare, convert
+# import bayesian_torch.models.bayesian.resnet_variational_large as resnet
+# from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
+
+# model = resnet.__dict__['resnet50']()
+
+# input = torch.randn(1, 3, 224, 224)
+# mp = prepare(model)
+# mp(input)  # haven't replaced the batchnorm layer
+# qmodel = torch.quantization.convert(mp)
+# bnn_to_qbnn(qmodel)
+
+
+import torch
+import bayesian_torch
+import bayesian_torch.models.bayesian.resnet_variational_large as resnet
+
+m = resnet.__dict__['resnet50']()
+# alternative way to construct a bnn model:
+# from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn
+# m = torchvision.models.resnet50(weights="IMAGENET1K_V1")
+# dnn_to_bnn(m)
+
+
+mp = bayesian_torch.quantization.prepare(m)
+input = torch.randn(1, 3, 224, 224)
+mp(input)  # calibration
+mq = bayesian_torch.quantization.convert(mp)
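
One caveat on the calibration step above: mp(input) pushes a single random batch through the observer-instrumented model, which is enough to smoke-test the flow but not to collect useful activation ranges. In practice one would loop over representative data. A small sketch, where calib_loader is a hypothetical DataLoader of real inputs:

import torch

mp.eval()
with torch.no_grad():
    for images, _ in calib_loader:  # hypothetical DataLoader of representative inputs
        mp(images)                  # observers accumulate per-tensor min/max statistics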

bayesian_torch/layers/variational_layers/quantize_conv_variational.py

Lines changed: 67 additions & 3 deletions

@@ -93,6 +93,7 @@ def __init__(self,
         self.bn_eps = None
 
         self.is_dequant = False
+        self.quant_dict = None
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -237,7 +238,26 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
         if self.dnn_to_bnn_flag:
             return_kl = False
 
-        if not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
+        if self.quant_dict is not None:
+            eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8)  # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
+            weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
+            weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'])
+            bias = None
+
+            ## DO NOT QUANTIZE BIAS!!!
+            if self.bias:
+                if self.quantized_sigma_bias is None:  # the case that bias comes from bn fusion
+                    bias = self.quantized_mu_bias
+                else:  # original case
+                    bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_())
+
+            if input.dtype != torch.quint8:  # check if input has been quantized
+                input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8)  # scale=0.1 by grid search; zero_point=128 for uint8 format
+
+            out = torch.nn.quantized.functional.conv1d(input, weight, bias, self.stride, self.padding,
+                                                       self.dilation, self.groups, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point'])  # input: quint8, weight: qint8, bias: fp32
+
+        elif not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
             if not self.is_dequant:
                 self.dequantize()
                 self.is_dequant = True
@@ -323,6 +343,7 @@ def __init__(self,
         self.bn_eps = None
 
         self.is_dequant = False
+        self.quant_dict = None
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -419,6 +440,10 @@ def quantize(self):
         delattr(self, "bn_running_var")
         delattr(self, "bn_eps")
 
+        delattr(self, "qint_quant")
+        delattr(self, "quint_quant")
+        delattr(self, "dequant")
+
     def dequantize(self):  # Deprecated. Only for forward mode #1.
         self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight)
         self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight)
@@ -466,7 +491,26 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
         if self.dnn_to_bnn_flag:
             return_kl = False
 
-        if not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
+        if self.quant_dict is not None:
+            eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8)  # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
+            weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
+            weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'])
+            bias = None
+
+            ## DO NOT QUANTIZE BIAS!!!
+            if self.bias:
+                if self.quantized_sigma_bias is None:  # the case that bias comes from bn fusion
+                    bias = self.quantized_mu_bias
+                else:  # original case
+                    bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_())
+
+            if input.dtype != torch.quint8:  # check if input has been quantized
+                input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8)  # scale=0.1 by grid search; zero_point=128 for uint8 format
+
+            out = torch.nn.quantized.functional.conv2d(input, weight, bias, self.stride, self.padding,
+                                                       self.dilation, self.groups, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point'])  # input: quint8, weight: qint8, bias: fp32
+
+        elif not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
             if not self.is_dequant:
                 self.dequantize()
                 self.is_dequant = True
@@ -550,6 +594,7 @@ def __init__(self,
         self.bn_eps = None
 
         self.is_dequant = False
+        self.quant_dict = None
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -693,7 +738,26 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
         if self.dnn_to_bnn_flag:
             return_kl = False
 
-        if not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
+        if self.quant_dict is not None:
+            eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8)  # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
+            weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
+            weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'])
+            bias = None
+
+            ## DO NOT QUANTIZE BIAS!!!
+            if self.bias:
+                if self.quantized_sigma_bias is None:  # the case that bias comes from bn fusion
+                    bias = self.quantized_mu_bias
+                else:  # original case
+                    bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_())
+
+            if input.dtype != torch.quint8:  # check if input has been quantized
+                input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8)  # scale=0.1 by grid search; zero_point=128 for uint8 format
+
+            out = torch.nn.quantized.functional.conv3d(input, weight, bias, self.stride, self.padding,
+                                                       self.dilation, self.groups, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point'])  # input: quint8, weight: qint8, bias: fp32
+
+        elif not enable_int8_compute:  # Deprecated. Use this method for reducing model size only.
            if not self.is_dequant:
                 self.dequantize()
                 self.is_dequant = True
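
The quant_dict branch added above carries out the reparameterization trick, weight = mu + sigma * eps, entirely in the quantized domain: a fresh standard-normal eps is quantized to qint8 (the comment's 3-sigma range of 6 gives scale = 6/255 ≈ 0.0235), multiplied into the quantized sigma, added to the quantized mu, and fed to a quantized conv; only the bias stays in fp32. A self-contained sketch of that sampling step, with made-up scales and zero points standing in for the calibrated quant_dict entries:

import torch

# Placeholder quantized posterior parameters; real values come from calibration.
mu = torch.quantize_per_tensor(torch.randn(8, 3, 3, 3), 0.02, 0, torch.qint8)
sigma = torch.quantize_per_tensor(torch.rand(8, 3, 3, 3) * 0.1, 0.001, 0, torch.qint8)

# Sample eps ~ N(0, 1); 99.7% of draws fall in [-3, 3], hence scale = 6 / 255.
eps = torch.quantize_per_tensor(torch.randn(8, 3, 3, 3), 6 / 255, 0, torch.qint8)

# weight = mu + sigma * eps, via quantized mul/add with explicit output qparams.
w = torch.ops.quantized.mul(sigma, eps, 0.0005, 0)  # placeholder output scale/zp
w = torch.ops.quantized.add(w, mu, 0.02, 0)         # placeholder output scale/zp
print(w.dtype, w.shape)                             # torch.qint8, (8, 3, 3, 3)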

bayesian_torch/models/bayesian/resnet_variational_large.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 from bayesian_torch.layers import BatchNorm2dLayer
 
 __all__ = [
-    'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
+    'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'BasicBlock', 'Bottleneck'
 ]
 
 prior_mu = 0.0
@@ -200,7 +200,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
                     posterior_mu_init=posterior_mu_init,
                     posterior_rho_init=posterior_rho_init,
                     bias=False),
-                BatchNorm2dLayer(planes * block.expansion),
+                nn.BatchNorm2d(planes * block.expansion),
             )
 
         layers = []
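
A note on the BatchNorm2dLayer to nn.BatchNorm2d swap above: eager-mode torch.quantization.convert() only rewrites module classes found in its default mapping, so a custom BN class would be silently left in fp32. A quick check, assuming a recent PyTorch with the torch.ao.quantization namespace:

import torch.nn as nn
from torch.ao.quantization import get_default_static_quant_module_mappings

mappings = get_default_static_quant_module_mappings()
print(nn.BatchNorm2d in mappings)  # True: convert() has a quantized counterpart for it
# A project-specific BatchNorm2dLayer is absent from this dict and stays unconverted.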
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .quantize import *
+
+# __all__ = ['prepare', 'convert']
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from bayesian_torch.ao.quantization.quantize import prepare
+from bayesian_torch.ao.quantization.quantize import convert
