Commit 8fe84f9

Merge pull request #11 from msubedar/dnn2bnn
Adds a new feature that builds a Bayesian deep neural network from a predefined deterministic model architecture: Convolutional, Linear, and LSTM layers are replaced in place with their Bayesian counterparts, configured through 'const_bnn_prior_parameters' to set the prior parameters and the sampling estimator (Reparameterization or Flipout). This PR enables seamless conversion of existing large-model topologies into Bayesian models, extending them toward uncertainty-aware applications.
2 parents daa7292 + b98ff19 commit 8fe84f9
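In practice, the conversion is meant to be a one-liner around an existing torch.nn model. Below is a minimal sketch of the intended flow; the helper names dnn_to_bnn and get_kl_loss and their import path are assumptions based on the example script added in this PR, not a verbatim quote of its API:

    import torch
    import torchvision.models as models
    from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss  # path assumed

    const_bnn_prior_parameters = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Flipout",  # or "Reparameterization"
    }

    model = models.resnet18()
    dnn_to_bnn(model, const_bnn_prior_parameters)  # swap layers in place

    x = torch.randn(2, 3, 224, 224)
    out = model(x)           # forward signature unchanged from the DNN
    kl = get_kl_loss(model)  # KL term gathered from the Bayesian layers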

File tree

9 files changed: +968 −132 lines changed

bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py

Lines changed: 522 additions & 0 deletions
Large diffs are not rendered by default.

bayesian_torch/layers/base_variational_layer.py

Lines changed: 9 additions & 0 deletions
@@ -34,6 +34,15 @@
 class BaseVariationalLayer_(nn.Module):
     def __init__(self):
         super().__init__()
+        self._dnn_to_bnn_flag = False
+
+    @property
+    def dnn_to_bnn_flag(self):
+        return self._dnn_to_bnn_flag
+
+    @dnn_to_bnn_flag.setter
+    def dnn_to_bnn_flag(self, value):
+        self._dnn_to_bnn_flag = value
 
     def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
         """

bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 51 additions & 23 deletions
@@ -154,6 +154,9 @@ def init_parameters(self):
 
     def forward(self, x, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         # linear outputs
         outputs = F.conv1d(x,
                            weight=self.mu_kernel,
@@ -173,16 +176,18 @@ def forward(self, x, return_kl=True):
 
         delta_kernel = (sigma_weight * eps_kernel)
 
-        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
-            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
-                                  self.prior_bias_sigma)
+            if return_kl:
+                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                      self.prior_bias_sigma)
 
         # perturbed feedforward
         perturbed_outputs = F.conv1d(x * sign_input,
@@ -308,6 +313,9 @@ def init_parameters(self):
 
     def forward(self, x, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         # linear outputs
         outputs = F.conv2d(x,
                            weight=self.mu_kernel,
@@ -327,16 +335,18 @@ def forward(self, x, return_kl=True):
 
         delta_kernel = (sigma_weight * eps_kernel)
 
-        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
-            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
-                                  self.prior_bias_sigma)
+            if return_kl:
+                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                      self.prior_bias_sigma)
 
         # perturbed feedforward
         perturbed_outputs = F.conv2d(x * sign_input,
@@ -347,7 +357,6 @@ def forward(self, x, return_kl=True):
                                      dilation=self.dilation,
                                      groups=self.groups) * sign_output
 
-        self.kl = kl
         # returning outputs + perturbations
         if return_kl:
             return outputs + perturbed_outputs, kl
@@ -462,6 +471,9 @@ def init_parameters(self):
 
     def forward(self, x, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         # linear outputs
         outputs = F.conv3d(x,
                            weight=self.mu_kernel,
@@ -481,16 +493,18 @@ def forward(self, x, return_kl=True):
 
         delta_kernel = (sigma_weight * eps_kernel)
 
-        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
-            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
-                                  self.prior_bias_sigma)
+            if return_kl:
+                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                      self.prior_bias_sigma)
 
         # perturbed feedforward
         perturbed_outputs = F.conv3d(x * sign_input,
@@ -612,6 +626,9 @@ def init_parameters(self):
 
     def forward(self, x, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         # linear outputs
         outputs = F.conv_transpose1d(x,
                                      weight=self.mu_kernel,
@@ -631,16 +648,18 @@ def forward(self, x, return_kl=True):
 
         delta_kernel = (sigma_weight * eps_kernel)
 
-        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
-            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
-                                  self.prior_bias_sigma)
+            if return_kl:
+                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                      self.prior_bias_sigma)
 
         # perturbed feedforward
         perturbed_outputs = F.conv_transpose1d(
@@ -767,6 +786,9 @@ def init_parameters(self):
 
     def forward(self, x, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         # linear outputs
         outputs = F.conv_transpose2d(x,
                                      bias=self.mu_bias,
@@ -786,16 +808,18 @@ def forward(self, x, return_kl=True):
 
         delta_kernel = (sigma_weight * eps_kernel)
 
-        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
-            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
-                                  self.prior_bias_sigma)
+            if return_kl:
+                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                      self.prior_bias_sigma)
 
         # perturbed feedforward
         perturbed_outputs = F.conv_transpose2d(
@@ -922,6 +946,9 @@ def init_parameters(self):
 
     def forward(self, x, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         # linear outputs
         outputs = F.conv_transpose3d(x,
                                      weight=self.mu_kernel,
@@ -941,8 +968,9 @@ def forward(self, x, return_kl=True):
 
         delta_kernel = (sigma_weight * eps_kernel)
 
-        kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.bias:
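The same return_kl gating is applied uniformly across the Conv1d/2d/3d and transposed-convolution Flipout layers, so each now supports two calling conventions. A sketch, with the class name assumed from this repo's layer naming:

    import torch
    from bayesian_torch.layers import Conv2dFlipout  # class name assumed

    conv = Conv2dFlipout(in_channels=3, out_channels=8, kernel_size=3)
    x = torch.randn(1, 3, 16, 16)

    out, kl = conv(x)               # default: KL computed and returned
    out = conv(x, return_kl=False)  # KL computation skipped entirely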

bayesian_torch/layers/flipout_layers/linear_flipout.py

Lines changed: 16 additions & 8 deletions
@@ -90,8 +90,6 @@ def __init__(self,
             torch.Tensor(out_features, in_features),
             persistent=False)
 
-        self.kl = 0
-
         if bias:
             self.mu_bias = nn.Parameter(torch.Tensor(out_features))
             self.rho_bias = nn.Parameter(torch.Tensor(out_features))
@@ -125,21 +123,33 @@ def init_parameters(self):
             self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
             self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)
 
+    def kl_loss(self):
+        sigma_weight = torch.log1p(torch.exp(self.rho_weight))
+        kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.mu_bias is not None:
+            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
+            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+        return kl
+
     def forward(self, x, return_kl=True):
+        if self.dnn_to_bnn_flag:
+            return_kl = False
         # sampling delta_W
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
         delta_weight = (sigma_weight * self.eps_weight.data.normal_())
 
         # get kl divergence
-        kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
-                         self.prior_weight_sigma)
+        if return_kl:
+            kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
+                             self.prior_weight_sigma)
 
         bias = None
         if self.mu_bias is not None:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             bias = (sigma_bias * self.eps_bias.data.normal_())
-            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
-                                  self.prior_bias_sigma)
+            if return_kl:
+                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                      self.prior_bias_sigma)
 
         # linear outputs
         outputs = F.linear(x, self.mu_weight, self.mu_bias)
@@ -150,8 +160,6 @@ def forward(self, x, return_kl=True):
         perturbed_outputs = F.linear(x * sign_input, delta_weight,
                                      bias) * sign_output
 
-        self.kl = kl
-
         # returning outputs + perturbations
         if return_kl:
             return outputs + perturbed_outputs, kl
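With kl_loss() as a standalone method, the KL term no longer has to ride along with the forward pass; it can be computed on demand after a plain forward call. A brief sketch under the same naming assumptions:

    import torch
    from bayesian_torch.layers import LinearFlipout  # class name assumed

    fc = LinearFlipout(in_features=8, out_features=4)
    out = fc(torch.randn(2, 8), return_kl=False)  # no KL from forward()
    kl = fc.kl_loss()                             # computed on demand instead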

bayesian_torch/layers/flipout_layers/rnn_flipout.py

Lines changed: 8 additions & 0 deletions
@@ -94,8 +94,16 @@ def __init__(self,
             out_features=out_features * 4,
             bias=bias)
 
+    def kl_loss(self):
+        kl_i = self.ih.kl_loss()
+        kl_h = self.hh.kl_loss()
+        return kl_i + kl_h
+
     def forward(self, X, hidden_states=None, return_kl=True):
 
+        if self.dnn_to_bnn_flag:
+            return_kl = False
+
         batch_size, seq_size, _ = X.size()
 
         hidden_seq = []
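Since the Bayesian LSTM cell is composed of two Flipout linear projections (input-to-hidden ih and hidden-to-hidden hh), its kl_loss() simply sums theirs. A sketch, with the class name and constructor arguments assumed:

    from bayesian_torch.layers import LSTMFlipout  # class name assumed

    lstm = LSTMFlipout(in_features=8, out_features=16)
    kl = lstm.kl_loss()  # equals lstm.ih.kl_loss() + lstm.hh.kl_loss()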
