
Commit 97ebe28

Include kl_loss() function in convolutional flipout layers to compute KL when the 'return_kl' flag is set to False. Fix for issue #12.

Signed-off-by: Ranganath Krishnan <[email protected]>
1 parent 1e85c50 commit 97ebe28
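
For context, here is a minimal usage sketch of what this fix enables. The import path, constructor arguments, and the single-tensor return under return_kl=False are assumptions inferred from the layer signatures visible in the diff, not guaranteed by this commit; the same pattern applies to all six flipout conv layers touched here.

```python
import torch
from bayesian_torch.layers import Conv2dFlipout  # assumed import path

layer = Conv2dFlipout(in_channels=3, out_channels=16, kernel_size=3)
x = torch.randn(8, 3, 32, 32)

# Previously, the KL term was only reachable through the forward return value:
out, kl = layer(x, return_kl=True)

# With return_kl=False, forward returns just the output; the new kl_loss()
# lets callers recover the KL term separately afterwards:
out = layer(x, return_kl=False)
kl = layer.kl_loss()
```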

File tree

1 file changed: +48 −0 lines changed

bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -152,6 +152,14 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
+        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.bias:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
 
         if self.dnn_to_bnn_flag:
```
```diff
@@ -311,6 +319,14 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
+        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.bias:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
 
         if self.dnn_to_bnn_flag:
```
```diff
@@ -469,6 +485,14 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
+        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.bias:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
 
         if self.dnn_to_bnn_flag:
```
```diff
@@ -624,6 +648,14 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
+        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.bias:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
 
         if self.dnn_to_bnn_flag:
```
```diff
@@ -784,6 +816,14 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
+        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.bias:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
 
         if self.dnn_to_bnn_flag:
```
```diff
@@ -944,6 +984,14 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
+        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.bias:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
 
         if self.dnn_to_bnn_flag:
```
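
In every hunk the pattern is the same: sigma is recovered from the unconstrained rho via the softplus log1p(exp(rho)), which keeps it strictly positive, and self.kl_div accumulates the weight term plus, when a bias is present, the bias term. For reference, here is a sketch of the closed-form KL divergence between diagonal Gaussians that kl_div presumably implements; the signature and the sum reduction are assumptions, not the library's verbatim code.

```python
import torch

def kl_div(mu_q, sigma_q, mu_p, sigma_p):
    # Element-wise KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) )
    # for diagonal Gaussians, reduced over all parameter entries.
    kl = (torch.log(sigma_p) - torch.log(sigma_q)
          + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
          - 0.5)
    return kl.sum()  # the reduction is an assumption; the library may average
```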
