@@ -154,6 +154,9 @@ def init_parameters(self):
154
154
155
155
def forward (self , x , return_kl = True ):
156
156
157
+ if self .dnn_to_bnn_flag :
158
+ return_kl = False
159
+
157
160
# linear outputs
158
161
outputs = F .conv1d (x ,
159
162
weight = self .mu_kernel ,
@@ -173,16 +176,18 @@ def forward(self, x, return_kl=True):
173
176
174
177
delta_kernel = (sigma_weight * eps_kernel )
175
178
176
- kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
177
- self .prior_weight_sigma )
179
+ if return_kl :
180
+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
181
+ self .prior_weight_sigma )
178
182
179
183
bias = None
180
184
if self .bias :
181
185
sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
182
186
eps_bias = self .eps_bias .data .normal_ ()
183
187
bias = (sigma_bias * eps_bias )
184
- kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
185
- self .prior_bias_sigma )
188
+ if return_kl :
189
+ kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
190
+ self .prior_bias_sigma )
186
191
187
192
# perturbed feedforward
188
193
perturbed_outputs = F .conv1d (x * sign_input ,
@@ -308,6 +313,9 @@ def init_parameters(self):
308
313
309
314
def forward (self , x , return_kl = True ):
310
315
316
+ if self .dnn_to_bnn_flag :
317
+ return_kl = False
318
+
311
319
# linear outputs
312
320
outputs = F .conv2d (x ,
313
321
weight = self .mu_kernel ,
@@ -327,16 +335,18 @@ def forward(self, x, return_kl=True):
327
335
328
336
delta_kernel = (sigma_weight * eps_kernel )
329
337
330
- kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
331
- self .prior_weight_sigma )
338
+ if return_kl :
339
+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
340
+ self .prior_weight_sigma )
332
341
333
342
bias = None
334
343
if self .bias :
335
344
sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
336
345
eps_bias = self .eps_bias .data .normal_ ()
337
346
bias = (sigma_bias * eps_bias )
338
- kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
339
- self .prior_bias_sigma )
347
+ if return_kl :
348
+ kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
349
+ self .prior_bias_sigma )
340
350
341
351
# perturbed feedforward
342
352
perturbed_outputs = F .conv2d (x * sign_input ,
@@ -347,7 +357,6 @@ def forward(self, x, return_kl=True):
347
357
dilation = self .dilation ,
348
358
groups = self .groups ) * sign_output
349
359
350
- self .kl = kl
351
360
# returning outputs + perturbations
352
361
if return_kl :
353
362
return outputs + perturbed_outputs , kl
@@ -462,6 +471,9 @@ def init_parameters(self):
462
471
463
472
def forward (self , x , return_kl = True ):
464
473
474
+ if self .dnn_to_bnn_flag :
475
+ return_kl = False
476
+
465
477
# linear outputs
466
478
outputs = F .conv3d (x ,
467
479
weight = self .mu_kernel ,
@@ -481,16 +493,18 @@ def forward(self, x, return_kl=True):
481
493
482
494
delta_kernel = (sigma_weight * eps_kernel )
483
495
484
- kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
485
- self .prior_weight_sigma )
496
+ if return_kl :
497
+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
498
+ self .prior_weight_sigma )
486
499
487
500
bias = None
488
501
if self .bias :
489
502
sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
490
503
eps_bias = self .eps_bias .data .normal_ ()
491
504
bias = (sigma_bias * eps_bias )
492
- kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
493
- self .prior_bias_sigma )
505
+ if return_kl :
506
+ kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
507
+ self .prior_bias_sigma )
494
508
495
509
# perturbed feedforward
496
510
perturbed_outputs = F .conv3d (x * sign_input ,
@@ -612,6 +626,9 @@ def init_parameters(self):
612
626
613
627
def forward (self , x , return_kl = True ):
614
628
629
+ if self .dnn_to_bnn_flag :
630
+ return_kl = False
631
+
615
632
# linear outputs
616
633
outputs = F .conv_transpose1d (x ,
617
634
weight = self .mu_kernel ,
@@ -631,16 +648,18 @@ def forward(self, x, return_kl=True):
631
648
632
649
delta_kernel = (sigma_weight * eps_kernel )
633
650
634
- kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
635
- self .prior_weight_sigma )
651
+ if return_kl :
652
+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
653
+ self .prior_weight_sigma )
636
654
637
655
bias = None
638
656
if self .bias :
639
657
sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
640
658
eps_bias = self .eps_bias .data .normal_ ()
641
659
bias = (sigma_bias * eps_bias )
642
- kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
643
- self .prior_bias_sigma )
660
+ if return_kl :
661
+ kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
662
+ self .prior_bias_sigma )
644
663
645
664
# perturbed feedforward
646
665
perturbed_outputs = F .conv_transpose1d (
@@ -767,6 +786,9 @@ def init_parameters(self):
767
786
768
787
def forward (self , x , return_kl = True ):
769
788
789
+ if self .dnn_to_bnn_flag :
790
+ return_kl = False
791
+
770
792
# linear outputs
771
793
outputs = F .conv_transpose2d (x ,
772
794
bias = self .mu_bias ,
@@ -786,16 +808,18 @@ def forward(self, x, return_kl=True):
786
808
787
809
delta_kernel = (sigma_weight * eps_kernel )
788
810
789
- kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
790
- self .prior_weight_sigma )
811
+ if return_kl :
812
+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
813
+ self .prior_weight_sigma )
791
814
792
815
bias = None
793
816
if self .bias :
794
817
sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
795
818
eps_bias = self .eps_bias .data .normal_ ()
796
819
bias = (sigma_bias * eps_bias )
797
- kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
798
- self .prior_bias_sigma )
820
+ if return_kl :
821
+ kl = kl + self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu ,
822
+ self .prior_bias_sigma )
799
823
800
824
# perturbed feedforward
801
825
perturbed_outputs = F .conv_transpose2d (
@@ -922,6 +946,9 @@ def init_parameters(self):
922
946
923
947
def forward (self , x , return_kl = True ):
924
948
949
+ if self .dnn_to_bnn_flag :
950
+ return_kl = False
951
+
925
952
# linear outputs
926
953
outputs = F .conv_transpose3d (x ,
927
954
weight = self .mu_kernel ,
@@ -941,8 +968,9 @@ def forward(self, x, return_kl=True):
941
968
942
969
delta_kernel = (sigma_weight * eps_kernel )
943
970
944
- kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
945
- self .prior_weight_sigma )
971
+ if return_kl :
972
+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu ,
973
+ self .prior_weight_sigma )
946
974
947
975
bias = None
948
976
if self .bias :
0 commit comments