
Commit 51bab43: quantization prepare function
(parent: 3360bcf)

File tree: 2 files changed (+34, −246 lines)

bayesian_torch/layers/variational_layers/conv_variational.py (34 additions, 1 deletion)
@@ -295,6 +295,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True
 
     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
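Note on the new prepare() method: it attaches eager-mode quantization stubs to the layer, five qint8 stubs for the weight-path tensors, two quint8 stubs for the layer input and output, plus one DeQuantStub. As shown, each QConfig is built with only an activation observer; torch.quantization.QConfig is a namedtuple that requires both an activation and a weight field, so that call raises a TypeError as written. A minimal runnable sketch of the same idea follows (reusing the observer for the weight field is an assumption, not taken from the commit):

import torch
import torch.nn as nn
from torch.quantization import (DeQuantStub, HistogramObserver, QConfig,
                                QuantStub)

def prepare(self):
    # One observer spec per dtype; QConfig needs both fields to construct.
    qint8_cfg = QConfig(
        activation=HistogramObserver.with_args(dtype=torch.qint8),
        weight=HistogramObserver.with_args(dtype=torch.qint8))
    quint8_cfg = QConfig(
        activation=HistogramObserver.with_args(dtype=torch.quint8),
        weight=HistogramObserver.with_args(dtype=torch.quint8))
    self.qint_quant = nn.ModuleList([QuantStub(qint8_cfg) for _ in range(5)])
    self.quint_quant = nn.ModuleList([QuantStub(quint8_cfg) for _ in range(2)])
    self.dequant = DeQuantStub()
    self.quant_prepare = True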
@@ -325,7 +334,8 @@ def forward(self, input, return_kl=True):
 
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
 
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
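This refactor is what enables the instrumentation that follows: writing the reparameterized weight w = mu + sigma * eps in two steps gives the intermediate product its own name, so the multiply and the add can each be routed through a dedicated observer in forward(). A toy illustration of the split (shapes and values are arbitrary, chosen only to show the equivalence):

import torch

mu = torch.zeros(3, 3)
rho = torch.zeros(3, 3)
sigma = torch.log1p(torch.exp(rho))  # softplus, as in the layer
eps = torch.randn(3, 3)
tmp = sigma * eps                    # the new intermediate tensor
w = mu + tmp                         # identical to mu + (sigma * eps)
assert torch.equal(w, mu + sigma * eps)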
@@ -342,6 +352,20 @@ def forward(self, input, return_kl=True):
 
         out = F.conv2d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstubs
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstubs
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight sigma
+            mu_kernel = self.qint_quant[1](self.mu_kernel)  # weight mean
+            eps_kernel = self.qint_quant[2](eps_kernel)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
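Each stub wraps an observer, so during a calibration forward pass the value ranges of input, out, sigma_weight, mu_kernel, eps_kernel, tmp_result, and weight are recorded individually. After calibration, the quantization parameters derived from those ranges can be read back. A hedged sketch, assuming a layer on which both the new prepare() and torch.quantization.prepare have already been run:

# 'layer' is assumed to be a prepared Conv2dReparameterization instance.
stub = layer.qint_quant[3]                   # saw tmp_result = sigma * eps
obs = stub.activation_post_process           # observer attached by prepare()
scale, zero_point = obs.calculate_qparams()  # qparams from observed ranges
print(scale, zero_point)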
@@ -946,3 +970,12 @@ def forward(self, input, return_kl=True):
             return out, kl
 
         return out
+
+if __name__ == "__main__":
+    m = Conv2dReparameterization(3, 3, 3)
+    m.eval()
+    m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+    mp = torch.quantization.prepare(m)
+    input = torch.randn(3, 3, 4, 4)
+    mp(input)
+    mq = torch.quantization.convert(mp)
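The __main__ smoke test exercises the standard eager-mode post-training quantization flow: set a qconfig, insert observers with torch.quantization.prepare, run one batch to calibrate, then torch.quantization.convert. Note that it never calls m.prepare(), so quant_prepare stays False and the new stub branch in forward() is skipped. A sketch with that call added (assuming calling it first is the intended usage):

import torch

m = Conv2dReparameterization(3, 3, 3)
m.prepare()  # attach the commit's QuantStubs so the forward() branch runs
m.eval()
m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
mp = torch.quantization.prepare(m)
mp(torch.randn(3, 3, 4, 4))  # calibration pass records tensor ranges
mq = torch.quantization.convert(mp)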

bayesian_torch/layers/variational_layers/conv_variational2.py (0 additions, 245 deletions)

This file was deleted.
