Skip to content

Commit b4ce3f5

Browse files
committed
finish quantization function
1 parent df09408 commit b4ce3f5

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

bayesian_torch/layers/variational_layers/conv_variational.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,9 @@ def __init__(self,
301301

302302
def prepare(self):
303303
self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
304-
QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
304+
QConfig(weight=HistogramObserver.with_args(dtype=torch.qint8), activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
305305
self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
306-
QConfig(activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
306+
QConfig(weight=HistogramObserver.with_args(dtype=torch.quint8), activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
307307
self.dequant = torch.quantization.DeQuantStub()
308308
self.quant_prepare=True
309309

@@ -337,7 +337,7 @@ def forward(self, input, return_kl=True):
337337
sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
338338
eps_kernel = self.eps_kernel.data.normal_()
339339
tmp_result = sigma_weight * eps_kernel
340-
weight = mu_kernel + tmp_result
340+
weight = self.mu_kernel + tmp_result
341341

342342
if return_kl:
343343
kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
@@ -976,6 +976,7 @@ def forward(self, input, return_kl=True):
976976
if __name__=="__main__":
977977
m = Conv2dReparameterization(3,3,3)
978978
m.eval()
979+
m.prepare()
979980
m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
980981
mp = torch.quantization.prepare(m)
981982
input = torch.randn(3,3,4,4)

0 commit comments

Comments
 (0)