Commit b80be73

feat: add possibility to return no kl, save it as attribute
1 parent 7abcfe7 commit b80be73

File tree

3 files changed: +64 -16 lines changed

bayesian_torch/layers/variational_layers/conv_variational.py

Lines changed: 48 additions & 12 deletions
@@ -112,6 +112,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -160,7 +162,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -182,7 +184,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class Conv2dReparameterization(BaseVariationalLayer_):
@@ -239,6 +245,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -292,7 +300,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -313,8 +321,12 @@ def forward(self, input):
             kl = kl_weight + kl_bias
         else:
             kl = kl_weight
+
+        self.kl = kl
 
-        return out, kl
+        if return_kl:
+            return out, kl
+        return out
 
 
 class Conv3dReparameterization(BaseVariationalLayer_):
@@ -371,6 +383,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -424,7 +438,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -446,7 +460,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class ConvTranspose1dReparameterization(BaseVariationalLayer_):
@@ -504,6 +522,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -552,7 +572,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -575,7 +595,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class ConvTranspose2dReparameterization(BaseVariationalLayer_):
@@ -633,6 +657,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size))
@@ -686,7 +712,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -709,7 +735,11 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
 
 
 class ConvTranspose3dReparameterization(BaseVariationalLayer_):
@@ -768,6 +798,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -821,7 +853,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -844,4 +876,8 @@ def forward(self, input):
         else:
             kl = kl_weight
 
-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
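
A minimal usage sketch for the new flag on the convolutional layers, assuming Conv2dReparameterization is importable from bayesian_torch.layers (the exact import path and constructor defaults may differ by version); the layer and input sizes are hypothetical:

import torch
from bayesian_torch.layers import Conv2dReparameterization

# Hypothetical sizes, used only for illustration.
layer = Conv2dReparameterization(in_channels=3, out_channels=16, kernel_size=3)
x = torch.randn(8, 3, 32, 32)

# Default behavior is unchanged: forward returns (output, kl).
out, kl = layer(x, return_kl=True)

# With return_kl=False the output comes back alone; the KL term
# computed during the forward pass is saved on the module instead.
out = layer(x, return_kl=False)
kl = layer.kl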

bayesian_torch/layers/variational_layers/linear_variational.py

Lines changed: 8 additions & 2 deletions
@@ -83,6 +83,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
         self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
         self.register_buffer('eps_weight',
@@ -124,7 +126,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)
 
-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
         weight = self.mu_weight + \
             (sigma_weight * self.eps_weight.data.normal_())
@@ -143,5 +145,9 @@ def forward(self, input):
             kl = kl_weight + kl_bias
         else:
             kl = kl_weight
+
+        self.kl = kl
 
-        return out, kl
+        if return_kl:
+            return out, kl
+        return out
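
The same pattern applies to LinearReparameterization. Because return_kl=False makes forward return a single tensor, the Bayesian layers compose like ordinary nn.Module layers, and the total KL can be read back from the saved attributes afterwards. A sketch under the same import-path assumption, with hypothetical sizes:

import torch
import torch.nn as nn
from bayesian_torch.layers import LinearReparameterization

class TinyBNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = LinearReparameterization(in_features=10, out_features=32)
        self.fc2 = LinearReparameterization(in_features=32, out_features=2)

    def forward(self, x):
        # No (out, kl) tuples to unpack mid-network.
        x = torch.relu(self.fc1(x, return_kl=False))
        return self.fc2(x, return_kl=False)

model = TinyBNN()
logits = model(torch.randn(4, 10))

# Recover the total KL for the ELBO from the saved attributes.
kl_total = model.fc1.kl + model.fc2.kl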

bayesian_torch/layers/variational_layers/rnn_variational.py

Lines changed: 8 additions & 2 deletions
@@ -77,6 +77,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
 
+        self.kl = 0
+
         self.ih = LinearReparameterization(
             prior_mean=prior_mean,
             prior_variance=prior_variance,
@@ -95,7 +97,7 @@ def __init__(self,
             out_features=out_features * 4,
             bias=bias)
 
-    def forward(self, X, hidden_states=None):
+    def forward(self, X, hidden_states=None, return_kl=True):
 
         batch_size, seq_size, _ = X.size()
 
@@ -140,4 +142,8 @@ def forward(self, X, hidden_states=None):
         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
         c_ts = c_ts.transpose(0, 1).contiguous()
 
-        return hidden_seq, (hidden_seq, c_ts), kl
+        self.kl = kl
+
+        if return_kl:
+            return hidden_seq, (hidden_seq, c_ts), kl
+        return hidden_seq, (hidden_seq, c_ts)
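
For the recurrent layer the flag works the same way, except that forward also returns the hidden-state tuple. A sketch assuming the class in rnn_variational.py is LSTMReparameterization (its name is not visible in this diff) and that input is shaped (batch, seq_len, in_features):

import torch
from bayesian_torch.layers import LSTMReparameterization

lstm = LSTMReparameterization(in_features=8, out_features=16)  # hypothetical sizes
x = torch.randn(4, 12, 8)

hidden_seq, (h, c), kl = lstm(x, return_kl=True)   # default, as before
hidden_seq, (h, c) = lstm(x, return_kl=False)      # KL saved on the module instead
kl = lstm.kl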
