
Commit d9a26a5

Merge branch 'main' into quantization

2 parents 86adb6d + aa8e198

4 files changed: 5 additions, 10 deletions

bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 2 additions & 0 deletions

@@ -40,6 +40,7 @@
 from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
 from torch.quantization.qconfig import QConfig
 
+
 from torch.distributions.normal import Normal
 from torch.distributions.uniform import Uniform
 

@@ -419,6 +420,7 @@ def forward(self, x, return_kl=True):
         return out
 
 
+
 class Conv3dFlipout(BaseVariationalLayer_):
     def __init__(self,
                  in_channels,
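
For reference, the observers and QConfig imported in the first hunk are the pieces PyTorch's eager-mode quantization uses to record tensor ranges. A minimal sketch of assembling a custom QConfig from them (the settings below are illustrative, not necessarily the ones this branch uses):

import torch
from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
from torch.quantization.qconfig import QConfig

# Activations: histogram-based range estimation, asymmetric uint8.
# Weights: per-output-channel min/max, symmetric int8.
custom_qconfig = QConfig(
    activation=HistogramObserver.with_args(dtype=torch.quint8),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)

Assigning such a qconfig to a module before torch.quantization.prepare controls how its activations and weights are observed during calibration.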

bayesian_torch/layers/flipout_layers/linear_flipout.py

Lines changed: 1 addition & 0 deletions

@@ -195,3 +195,4 @@ def forward(self, x, return_kl=True):
         if return_kl:
             return out, kl
         return out
+

bayesian_torch/layers/variational_layers/conv_variational.py

Lines changed: 1 addition & 10 deletions

@@ -374,7 +374,7 @@ def forward(self, input, return_kl=True):
             else:
                 kl = kl_weight
             return out, kl
-
+
         return out
 
 

@@ -973,12 +973,3 @@ def forward(self, input, return_kl=True):
 
         return out
 
-if __name__=="__main__":
-    m = Conv2dReparameterization(3,3,3)
-    m.eval()
-    m.prepare()
-    m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
-    mp = torch.quantization.prepare(m)
-    input = torch.randn(3,3,4,4)
-    mp(input)
-    mq = torch.quantization.convert(mp)
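
The block removed in the second hunk was an ad-hoc smoke test of PyTorch's eager-mode post-training static quantization (prepare, calibrate, convert); it only ran when the file was executed directly, and removing it keeps the scratch test out of the library source. A self-contained sketch of the same flow, using a plain nn.Conv2d wrapped in quant/dequant stubs instead of the Bayesian layer (SmallConvNet is an illustrative name, not part of the repo):

import torch
import torch.nn as nn

class SmallConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # fp32 -> int8 at the input
        self.conv = nn.Conv2d(3, 3, 3)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> fp32 at the output

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = SmallConvNet()
m.eval()                                                      # PTQ runs in eval mode
m.qconfig = torch.quantization.get_default_qconfig("fbgemm")  # x86 backend
mp = torch.quantization.prepare(m)    # attach observers
mp(torch.randn(3, 3, 4, 4))           # calibration pass records activation ranges
mq = torch.quantization.convert(mp)   # swap in quantized modules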

bayesian_torch/layers/variational_layers/linear_variational.py

Lines changed: 1 addition & 0 deletions

@@ -162,6 +162,7 @@ def forward(self, input, return_kl=True):
         tmp_result = sigma_weight * eps_weight
         weight = self.mu_weight + tmp_result
 
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_weight, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
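
The context lines in this hunk are the reparameterization step (weight = mu_weight + sigma_weight * eps_weight) followed by the KL term against the prior. As a hedged sketch, a kl_div between two diagonal Gaussians like this typically evaluates the standard closed form (function and argument names below are illustrative, not the library's API):

import torch

def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    # KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), elementwise, then summed
    kl = (torch.log(sigma_p / sigma_q)
          + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
          - 0.5)
    return kl.sum()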
