@@ -152,6 +152,14 @@ def init_parameters(self):
152
152
self .prior_bias_mu .data .fill_ (self .prior_mean )
153
153
self .prior_bias_sigma .data .fill_ (self .prior_variance )
154
154
155
def kl_loss(self):
    """Return the KL divergence between the weight posterior and prior.

    Computes ``KL(q || p)`` for the kernel (weight) variational parameters
    via ``self.kl_div`` and, when the layer has a bias term (``self.bias``
    truthy), adds the corresponding bias KL.

    The posterior standard deviation is recovered from the unconstrained
    ``rho`` parameters with softplus: sigma = log(1 + exp(rho)).

    Returns:
        torch.Tensor: the (scalar) KL term produced by ``self.kl_div``.
    """
    # softplus(rho) == log1p(exp(rho)), but F.softplus is numerically
    # stable: log1p(exp(rho)) overflows to inf for large rho (~88 in fp32).
    sigma_weight = torch.nn.functional.softplus(self.rho_kernel)
    kl = self.kl_div(self.mu_kernel, sigma_weight,
                     self.prior_weight_mu, self.prior_weight_sigma)
    if self.bias:
        sigma_bias = torch.nn.functional.softplus(self.rho_bias)
        kl += self.kl_div(self.mu_bias, sigma_bias,
                          self.prior_bias_mu, self.prior_bias_sigma)
    return kl
162
+
155
163
def forward (self , x , return_kl = True ):
156
164
157
165
if self .dnn_to_bnn_flag :
@@ -311,6 +319,14 @@ def init_parameters(self):
311
319
self .prior_bias_mu .data .fill_ (self .prior_mean )
312
320
self .prior_bias_sigma .data .fill_ (self .prior_variance )
313
321
322
def kl_loss(self):
    """Return the KL divergence between the weight posterior and prior.

    Computes ``KL(q || p)`` for the kernel (weight) variational parameters
    via ``self.kl_div`` and, when the layer has a bias term (``self.bias``
    truthy), adds the corresponding bias KL.

    The posterior standard deviation is recovered from the unconstrained
    ``rho`` parameters with softplus: sigma = log(1 + exp(rho)).

    Returns:
        torch.Tensor: the (scalar) KL term produced by ``self.kl_div``.
    """
    # softplus(rho) == log1p(exp(rho)), but F.softplus is numerically
    # stable: log1p(exp(rho)) overflows to inf for large rho (~88 in fp32).
    sigma_weight = torch.nn.functional.softplus(self.rho_kernel)
    kl = self.kl_div(self.mu_kernel, sigma_weight,
                     self.prior_weight_mu, self.prior_weight_sigma)
    if self.bias:
        sigma_bias = torch.nn.functional.softplus(self.rho_bias)
        kl += self.kl_div(self.mu_bias, sigma_bias,
                          self.prior_bias_mu, self.prior_bias_sigma)
    return kl
329
+
314
330
def forward (self , x , return_kl = True ):
315
331
316
332
if self .dnn_to_bnn_flag :
@@ -469,6 +485,14 @@ def init_parameters(self):
469
485
self .prior_bias_mu .data .fill_ (self .prior_mean )
470
486
self .prior_bias_sigma .data .fill_ (self .prior_variance )
471
487
488
def kl_loss(self):
    """Return the KL divergence between the weight posterior and prior.

    Computes ``KL(q || p)`` for the kernel (weight) variational parameters
    via ``self.kl_div`` and, when the layer has a bias term (``self.bias``
    truthy), adds the corresponding bias KL.

    The posterior standard deviation is recovered from the unconstrained
    ``rho`` parameters with softplus: sigma = log(1 + exp(rho)).

    Returns:
        torch.Tensor: the (scalar) KL term produced by ``self.kl_div``.
    """
    # softplus(rho) == log1p(exp(rho)), but F.softplus is numerically
    # stable: log1p(exp(rho)) overflows to inf for large rho (~88 in fp32).
    sigma_weight = torch.nn.functional.softplus(self.rho_kernel)
    kl = self.kl_div(self.mu_kernel, sigma_weight,
                     self.prior_weight_mu, self.prior_weight_sigma)
    if self.bias:
        sigma_bias = torch.nn.functional.softplus(self.rho_bias)
        kl += self.kl_div(self.mu_bias, sigma_bias,
                          self.prior_bias_mu, self.prior_bias_sigma)
    return kl
495
+
472
496
def forward (self , x , return_kl = True ):
473
497
474
498
if self .dnn_to_bnn_flag :
@@ -624,6 +648,14 @@ def init_parameters(self):
624
648
self .prior_bias_mu .data .fill_ (self .prior_mean )
625
649
self .prior_bias_sigma .data .fill_ (self .prior_variance )
626
650
651
def kl_loss(self):
    """Return the KL divergence between the weight posterior and prior.

    Computes ``KL(q || p)`` for the kernel (weight) variational parameters
    via ``self.kl_div`` and, when the layer has a bias term (``self.bias``
    truthy), adds the corresponding bias KL.

    The posterior standard deviation is recovered from the unconstrained
    ``rho`` parameters with softplus: sigma = log(1 + exp(rho)).

    Returns:
        torch.Tensor: the (scalar) KL term produced by ``self.kl_div``.
    """
    # softplus(rho) == log1p(exp(rho)), but F.softplus is numerically
    # stable: log1p(exp(rho)) overflows to inf for large rho (~88 in fp32).
    sigma_weight = torch.nn.functional.softplus(self.rho_kernel)
    kl = self.kl_div(self.mu_kernel, sigma_weight,
                     self.prior_weight_mu, self.prior_weight_sigma)
    if self.bias:
        sigma_bias = torch.nn.functional.softplus(self.rho_bias)
        kl += self.kl_div(self.mu_bias, sigma_bias,
                          self.prior_bias_mu, self.prior_bias_sigma)
    return kl
658
+
627
659
def forward (self , x , return_kl = True ):
628
660
629
661
if self .dnn_to_bnn_flag :
@@ -784,6 +816,14 @@ def init_parameters(self):
784
816
self .prior_bias_mu .data .fill_ (self .prior_mean )
785
817
self .prior_bias_sigma .data .fill_ (self .prior_variance )
786
818
819
def kl_loss(self):
    """Return the KL divergence between the weight posterior and prior.

    Computes ``KL(q || p)`` for the kernel (weight) variational parameters
    via ``self.kl_div`` and, when the layer has a bias term (``self.bias``
    truthy), adds the corresponding bias KL.

    The posterior standard deviation is recovered from the unconstrained
    ``rho`` parameters with softplus: sigma = log(1 + exp(rho)).

    Returns:
        torch.Tensor: the (scalar) KL term produced by ``self.kl_div``.
    """
    # softplus(rho) == log1p(exp(rho)), but F.softplus is numerically
    # stable: log1p(exp(rho)) overflows to inf for large rho (~88 in fp32).
    sigma_weight = torch.nn.functional.softplus(self.rho_kernel)
    kl = self.kl_div(self.mu_kernel, sigma_weight,
                     self.prior_weight_mu, self.prior_weight_sigma)
    if self.bias:
        sigma_bias = torch.nn.functional.softplus(self.rho_bias)
        kl += self.kl_div(self.mu_bias, sigma_bias,
                          self.prior_bias_mu, self.prior_bias_sigma)
    return kl
826
+
787
827
def forward (self , x , return_kl = True ):
788
828
789
829
if self .dnn_to_bnn_flag :
@@ -944,6 +984,14 @@ def init_parameters(self):
944
984
self .prior_bias_mu .data .fill_ (self .prior_mean )
945
985
self .prior_bias_sigma .data .fill_ (self .prior_variance )
946
986
987
def kl_loss(self):
    """Return the KL divergence between the weight posterior and prior.

    Computes ``KL(q || p)`` for the kernel (weight) variational parameters
    via ``self.kl_div`` and, when the layer has a bias term (``self.bias``
    truthy), adds the corresponding bias KL.

    The posterior standard deviation is recovered from the unconstrained
    ``rho`` parameters with softplus: sigma = log(1 + exp(rho)).

    Returns:
        torch.Tensor: the (scalar) KL term produced by ``self.kl_div``.
    """
    # softplus(rho) == log1p(exp(rho)), but F.softplus is numerically
    # stable: log1p(exp(rho)) overflows to inf for large rho (~88 in fp32).
    sigma_weight = torch.nn.functional.softplus(self.rho_kernel)
    kl = self.kl_div(self.mu_kernel, sigma_weight,
                     self.prior_weight_mu, self.prior_weight_sigma)
    if self.bias:
        sigma_bias = torch.nn.functional.softplus(self.rho_bias)
        kl += self.kl_div(self.mu_bias, sigma_bias,
                          self.prior_bias_mu, self.prior_bias_sigma)
    return kl
994
+
947
995
def forward (self , x , return_kl = True ):
948
996
949
997
if self .dnn_to_bnn_flag :
0 commit comments