
Commit 87488e2

qbnn performance test
1 parent b780aad commit 87488e2

File tree

4 files changed, +13 -9 lines changed


bayesian_torch/layers/variational_layers/conv_variational.py

Lines changed: 3 additions & 3 deletions
@@ -48,7 +48,7 @@
 from torch.nn import Parameter
 from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size
 import math
-from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
+from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
 from torch.quantization.qconfig import QConfig

 __all__ = [
@@ -301,9 +301,9 @@ def __init__(self,

     def prepare(self):
         self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
-            QConfig(weight=HistogramObserver.with_args(dtype=torch.qint8), activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
         self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
-            QConfig(weight=HistogramObserver.with_args(dtype=torch.quint8), activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
         self.dequant = torch.quantization.DeQuantStub()
         self.quant_prepare=True
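Note (illustrative, not part of the diff): a minimal standalone sketch of the observer swap above. MinMaxObserver only tracks a running min/max, so calibration is cheaper than HistogramObserver's histogram search, and per_tensor_symmetric keeps the zero point of the qint8 path at 0; the names below are made up for the example.

import torch
from torch.quantization.observer import MinMaxObserver
from torch.quantization.qconfig import QConfig

# Symmetric per-tensor int8 config, analogous to the one used for the qint_quant stubs.
int8_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.qint8,
                                        qscheme=torch.per_tensor_symmetric),
    weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                    qscheme=torch.per_tensor_symmetric),
)
stub = torch.quantization.QuantStub(int8_qconfig)  # observes, then quantizes, its input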

bayesian_torch/layers/variational_layers/linear_variational.py

Lines changed: 5 additions & 3 deletions
@@ -47,6 +47,8 @@
 from torch.nn import Module, Parameter
 from ..base_variational_layer import BaseVariationalLayer_
 import math
+from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
+from torch.quantization.qconfig import QConfig


 class LinearReparameterization(BaseVariationalLayer_):
@@ -120,9 +122,9 @@ def __init__(self,

     def prepare(self):
         self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
-            QConfig(weight=HistogramObserver.with_args(dtype=torch.qint8), activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
         self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
-            QConfig(weight=HistogramObserver.with_args(dtype=torch.quint8), activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
         self.dequant = torch.quantization.DeQuantStub()
         self.quant_prepare=True

@@ -157,7 +159,7 @@ def forward(self, input, return_kl=True):
             return_kl = False
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
         eps_weight = self.eps_weight.data.normal_()
-        tmp_result = sigma_weight * eps_kernel
+        tmp_result = sigma_weight * eps_weight
         weight = self.mu_weight + tmp_result

         if return_kl:
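Note (illustrative, not part of the diff): `eps_kernel` is the name used in the convolutional layers; in LinearReparameterization.forward it is undefined, so the old line raised a NameError. A minimal sketch of the reparameterization step the fixed line computes, with hypothetical shapes:

import torch

mu_weight = torch.zeros(4, 3)           # variational mean (hypothetical shape)
rho_weight = torch.full((4, 3), -3.0)   # unconstrained scale parameter
eps_weight = torch.randn(4, 3)          # eps ~ N(0, 1)

sigma_weight = torch.log1p(torch.exp(rho_weight))  # softplus keeps sigma positive
weight = mu_weight + sigma_weight * eps_weight     # w = mu + sigma * eps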

bayesian_torch/layers/variational_layers/quantize_linear_variational.py

Lines changed: 3 additions & 3 deletions
@@ -118,8 +118,8 @@ def quantize(self):
         delattr(self, "mu_weight")
         delattr(self, "rho_weight")

-        self.quantized_mu_bias = Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False)
-        self.quantized_sigma_bias = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
+        self.quantized_mu_bias = self.mu_bias#Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False)
+        self.quantized_sigma_bias = torch.log1p(torch.exp(self.rho_bias))#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
         delattr(self, "mu_bias")
         delattr(self, "rho_bias")
125125

@@ -171,7 +171,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

         if self.quant_dict is not None:
             eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
-            weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
+            weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point'])
             weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'])
             bias = None
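Note (illustrative, not part of the diff): the earlier hunk in this file keeps the bias mean and sigma in float rather than quantizing them, while the fixed line above samples the weight entirely with quantized kernels, w = mu + sigma * eps. A standalone sketch of that integer-only sampling; the scale and zero point are made-up calibration values, not the layer's quant_dict entries.

import torch

scale, zero_point = 6.0 / 255, 0   # hypothetical calibration results
sigma_q = torch.quantize_per_tensor(torch.rand(4, 3), scale, zero_point, torch.qint8)
mu_q = torch.quantize_per_tensor(torch.randn(4, 3), scale, zero_point, torch.qint8)
eps_q = torch.quantize_per_tensor(torch.randn(4, 3), scale, zero_point, torch.qint8)

w_q = torch.ops.quantized.mul(sigma_q, eps_q, scale, zero_point)  # sigma * eps
w_q = torch.ops.quantized.add(w_q, mu_q, scale, zero_point)       # + mu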
177177

bayesian_torch/models/bnn_to_qbnn.py

Lines changed: 2 additions & 0 deletions
@@ -200,6 +200,8 @@ def bnn_to_qbnn(m, fuse_conv_bn=False):
         if m._modules[name]._modules:
             if "Conv" in m._modules[name].__class__.__name__:
                 setattr(m, name, qbnn_conv_layer(m._modules[name]))
+            elif "Linear" in m._modules[name].__class__.__name__:
+                setattr(m, name, qbnn_linear_layer(m._modules[name]))
             else:
                 bnn_to_qbnn(m._modules[name], fuse_conv_bn=fuse_conv_bn)
         elif "Linear" in m._modules[name].__class__.__name__:
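Note (illustrative, not part of the diff): with this change, children whose class name contains "Linear" and that own submodules (for example a prepared LinearReparameterization holding its quant stubs) are converted directly instead of only being recursed into. A runnable toy mirror of the traversal, with nn.Identity standing in for qbnn_linear_layer:

import torch.nn as nn

class MyLinearBlock(nn.Module):
    # Hypothetical block: "Linear" in the class name, and it owns submodules.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)

def convert(m):
    # Toy mirror of the bnn_to_qbnn traversal shown above.
    for name in list(m._modules):
        child = m._modules[name]
        if child._modules:                       # child has its own submodules
            if "Linear" in child.__class__.__name__:
                setattr(m, name, nn.Identity())  # stand-in for qbnn_linear_layer(child)
            else:
                convert(child)

model = nn.Sequential(MyLinearBlock())
convert(model)   # MyLinearBlock is swapped out in place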
