Commit 33fef0a

Add support for output padding in flipout layers

1 parent 7bf1a2e

File tree

1 file changed: +12 −0


bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 12 additions & 0 deletions
@@ -557,6 +557,7 @@ def __init__(self,
                  padding=0,
                  dilation=1,
                  groups=1,
+                 output_padding=0,
                  prior_mean=0,
                  prior_variance=1,
                  posterior_mu_init=0,
@@ -588,6 +589,7 @@ def __init__(self,
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        self.output_padding = output_padding
         self.dilation = dilation
         self.groups = groups
         self.bias = bias
@@ -669,6 +671,7 @@ def forward(self, x, return_kl=True):
                  bias=self.mu_bias,
                  stride=self.stride,
                  padding=self.padding,
+                 output_padding=self.output_padding,
                  dilation=self.dilation,
                  groups=self.groups)

@@ -702,6 +705,7 @@ def forward(self, x, return_kl=True):
                  bias=bias,
                  stride=self.stride,
                  padding=self.padding,
+                 output_padding=self.output_padding,
                  dilation=self.dilation,
                  groups=self.groups) * sign_output

@@ -719,6 +723,7 @@ def __init__(self,
                  kernel_size,
                  stride=1,
                  padding=0,
+                 output_padding=0,
                  dilation=1,
                  groups=1,
                  prior_mean=0,
@@ -752,6 +757,7 @@ def __init__(self,
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        self.output_padding = output_padding
         self.dilation = dilation
         self.groups = groups
         self.bias = bias
@@ -837,6 +843,7 @@ def forward(self, x, return_kl=True):
                  weight=self.mu_kernel,
                  stride=self.stride,
                  padding=self.padding,
+                 output_padding=self.output_padding,
                  dilation=self.dilation,
                  groups=self.groups)

@@ -870,6 +877,7 @@ def forward(self, x, return_kl=True):
                  weight=delta_kernel,
                  stride=self.stride,
                  padding=self.padding,
+                 output_padding=self.output_padding,
                  dilation=self.dilation,
                  groups=self.groups) * sign_output

@@ -887,6 +895,7 @@ def __init__(self,
                  kernel_size,
                  stride=1,
                  padding=0,
+                 output_padding=0,
                  dilation=1,
                  groups=1,
                  prior_mean=0,
@@ -920,6 +929,7 @@ def __init__(self,
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        self.output_padding = output_padding
         self.dilation = dilation
         self.groups = groups

@@ -1005,6 +1015,7 @@ def forward(self, x, return_kl=True):
                  bias=self.mu_bias,
                  stride=self.stride,
                  padding=self.padding,
+                 output_padding=self.output_padding,
                  dilation=self.dilation,
                  groups=self.groups)

@@ -1037,6 +1048,7 @@ def forward(self, x, return_kl=True):
                  bias=bias,
                  stride=self.stride,
                  padding=self.padding,
+                 output_padding=self.output_padding,
                  dilation=self.dilation,
                  groups=self.groups) * sign_output

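For context, not part of the commit itself: when stride > 1, several input sizes map to the same transposed-convolution output size, and output_padding (as in torch.nn.functional.conv_transpose2d) disambiguates by enlarging one side of the output. PyTorch computes H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1. Below is a minimal usage sketch of the new parameter, assuming ConvTranspose2dFlipout is exported from bayesian_torch.layers and that forward returns an (output, kl) pair like the other flipout layers; the sizes are hypothetical, chosen only to show the effect of output_padding.

    import torch
    from bayesian_torch.layers import ConvTranspose2dFlipout

    # Hypothetical configuration: a 2x-upsampling Bayesian deconvolution.
    layer = ConvTranspose2dFlipout(in_channels=8,
                                   out_channels=4,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1)  # parameter added by this commit

    x = torch.randn(1, 8, 16, 16)
    out, kl = layer(x)  # flipout layers also return the KL term
    print(out.shape)    # (16-1)*2 - 2 + (3-1) + 1 + 1 = 32 -> torch.Size([1, 4, 32, 32])

With output_padding=0 the same configuration would produce a 31x31 output, so the extra argument lets a Bayesian decoder exactly mirror the shapes of a strided Conv2dFlipout encoder; the same applies to the 1d and 3d transpose variants touched by this commit.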