@@ -112,6 +112,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -160,7 +162,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -182,7 +184,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class Conv2dReparameterization(BaseVariationalLayer_):
@@ -239,6 +245,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -292,7 +300,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -313,8 +321,12 @@ def forward(self, input):
             kl = kl_weight + kl_bias
         else:
             kl = kl_weight
+
+        self.kl = kl

-        return out, kl
+        if return_kl:
+            return out, kl
+        return out


 class Conv3dReparameterization(BaseVariationalLayer_):
@@ -371,6 +383,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -424,7 +438,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -446,7 +460,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class ConvTranspose1dReparameterization(BaseVariationalLayer_):
@@ -504,6 +522,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -552,7 +572,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -575,7 +595,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class ConvTranspose2dReparameterization(BaseVariationalLayer_):
@@ -633,6 +657,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size))
@@ -686,7 +712,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -709,7 +735,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class ConvTranspose3dReparameterization(BaseVariationalLayer_):
@@ -768,6 +798,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -821,7 +853,7 @@ def init_parameters(self):
             self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                        std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -844,4 +876,8 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
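
Taken together, the hunks above apply the same three changes to each of the six reparameterization layers: initialize self.kl = 0 in __init__, add a return_kl flag to forward (defaulting to True, so existing callers are unaffected), and cache the computed KL divergence on the module before returning. The sketch below is a minimal usage example, not part of this diff; it assumes the layer is importable from bayesian_torch.layers (the exact path may vary by version) and passes only the required constructor arguments.

import torch
from bayesian_torch.layers import Conv2dReparameterization

layer = Conv2dReparameterization(in_channels=3, out_channels=16, kernel_size=3)
x = torch.randn(8, 3, 32, 32)

# Default behavior is unchanged: forward() returns (output, kl).
out, kl = layer(x)

# With return_kl=False the layer returns only the output tensor, so it can
# stand in for an ordinary nn.Conv2d in code that expects one return value.
out = layer(x, return_kl=False)

# Either way the divergence is cached on the module, so the total KL for a
# network can be collected after the forward pass instead of being threaded
# through every intermediate return value:
total_kl = sum(m.kl for m in layer.modules() if hasattr(m, "kl"))

Caching the divergence as module state is what makes the single-output mode useful: containers such as nn.Sequential cannot propagate an (output, kl) tuple between layers, but they leave each layer's kl attribute in place for a later sweep over model.modules().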