Skip to content

Commit 97ba16a

Browse files
correcting the order of group and dilation parameters in Conv transpose layers.
Fix issue #21 Signed-off-by: Ranganath Krishnan <[email protected]>
1 parent 1180b87 commit 97ba16a

File tree

4 files changed

+27
-27
lines changed

4 files changed

+27
-27
lines changed

bayesian_torch/layers/flipout_layers/conv_flipout.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -769,8 +769,8 @@ def forward(self, x, return_kl=True):
769769
stride=self.stride,
770770
padding=self.padding,
771771
output_padding=self.output_padding,
772-
dilation=self.dilation,
773-
groups=self.groups)
772+
groups=self.groups,
773+
dilation=self.dilation)
774774

775775
# sampling perturbation signs
776776
sign_input = x.clone().uniform_(-1, 1).sign()
@@ -803,8 +803,8 @@ def forward(self, x, return_kl=True):
803803
stride=self.stride,
804804
padding=self.padding,
805805
output_padding=self.output_padding,
806-
dilation=self.dilation,
807-
groups=self.groups)
806+
groups=self.groups,
807+
dilation=self.dilation)
808808
perturbed_outputs = perturbed_outputs_tmp * sign_output
809809
out = outputs + perturbed_outputs
810810

@@ -968,8 +968,8 @@ def forward(self, x, return_kl=True):
968968
stride=self.stride,
969969
padding=self.padding,
970970
output_padding=self.output_padding,
971-
dilation=self.dilation,
972-
groups=self.groups)
971+
groups=self.groups,
972+
dilation=self.dilation)
973973

974974
# sampling perturbation signs
975975
sign_input = x.clone().uniform_(-1, 1).sign()
@@ -1002,8 +1002,8 @@ def forward(self, x, return_kl=True):
10021002
stride=self.stride,
10031003
padding=self.padding,
10041004
output_padding=self.output_padding,
1005-
dilation=self.dilation,
1006-
groups=self.groups)
1005+
groups=self.groups,
1006+
dilation=self.dilation)
10071007
perturbed_outputs = perturbed_outputs_tmp * sign_output
10081008
out = outputs + perturbed_outputs
10091009

@@ -1167,8 +1167,8 @@ def forward(self, x, return_kl=True):
11671167
stride=self.stride,
11681168
padding=self.padding,
11691169
output_padding=self.output_padding,
1170-
dilation=self.dilation,
1171-
groups=self.groups)
1170+
groups=self.groups,
1171+
dilation=self.dilation)
11721172

11731173
# sampling perturbation signs
11741174
sign_input = x.clone().uniform_(-1, 1).sign()
@@ -1200,8 +1200,8 @@ def forward(self, x, return_kl=True):
12001200
stride=self.stride,
12011201
padding=self.padding,
12021202
output_padding=self.output_padding,
1203-
dilation=self.dilation,
1204-
groups=self.groups)
1203+
groups=self.groups,
1204+
dilation=self.dilation)
12051205
perturbed_outputs = perturbed_outputs_tmp * sign_output
12061206
out = outputs + perturbed_outputs
12071207

bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
898898

899899
self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(self.quantized_mu_weight, bias, self.stride,
900900
self.padding, self.output_padding,
901-
self.dilation, self.groups)
901+
self.groups, self.dilation)
902902

903903
outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)
904904

@@ -923,7 +923,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
923923

924924
self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(delta_kernel, bias, self.stride,
925925
self.padding, self.output_padding,
926-
self.dilation, self.groups)
926+
self.groups, self.dilation)
927927
perturbed_outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)
928928

929929
perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
@@ -1106,7 +1106,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
11061106

11071107
self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(self.quantized_mu_weight, bias, self.stride,
11081108
self.padding, self.output_padding,
1109-
self.dilation, self.groups)
1109+
self.groups, self.dilation)
11101110

11111111
outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)
11121112

@@ -1131,7 +1131,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
11311131

11321132
self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(delta_kernel, bias, self.stride,
11331133
self.padding, self.output_padding,
1134-
self.dilation, self.groups)
1134+
self.groups, self.dilation)
11351135
perturbed_outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)
11361136

11371137
perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
@@ -1314,7 +1314,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
13141314

13151315
self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(self.quantized_mu_weight, bias, self.stride,
13161316
self.padding, self.output_padding,
1317-
self.dilation, self.groups)
1317+
self.groups, self.dilation)
13181318

13191319
outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)
13201320

@@ -1339,7 +1339,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
13391339

13401340
self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(delta_kernel, bias, self.stride,
13411341
self.padding, self.output_padding,
1342-
self.dilation, self.groups)
1342+
self.groups, self.dilation)
13431343
perturbed_outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)
13441344

13451345
perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)

bayesian_torch/layers/variational_layers/conv_variational.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ def forward(self, input, return_kl=True):
719719

720720
out = F.conv_transpose1d(input, weight, bias, self.stride,
721721
self.padding, self.output_padding,
722-
self.dilation, self.groups)
722+
self.groups, self.dilation)
723723

724724
if self.quant_prepare:
725725
# quint8 quantstub
@@ -894,7 +894,7 @@ def forward(self, input, return_kl=True):
894894

895895
out = F.conv_transpose2d(input, weight, bias, self.stride,
896896
self.padding, self.output_padding,
897-
self.dilation, self.groups)
897+
self.groups, self.dilation)
898898

899899
if self.quant_prepare:
900900
# quint8 quantstub
@@ -1070,7 +1070,7 @@ def forward(self, input, return_kl=True):
10701070

10711071
out = F.conv_transpose3d(input, weight, bias, self.stride,
10721072
self.padding, self.output_padding,
1073-
self.dilation, self.groups)
1073+
self.groups, self.dilation)
10741074

10751075
if self.quant_prepare:
10761076
# quint8 quantstub

bayesian_torch/layers/variational_layers/quantize_conv_variational.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
996996

997997
out = F.conv_transpose1d(input, weight, bias, self.stride,
998998
self.padding, self.output_padding,
999-
self.dilation, self.groups)
999+
self.groups, self.dilation)
10001000

10011001
else:
10021002
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
@@ -1019,7 +1019,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
10191019

10201020
self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(weight, bias, self.stride,
10211021
self.padding, self.output_padding,
1022-
self.dilation, self.groups)
1022+
self.groups, self.dilation)
10231023

10241024
out = torch.ops.quantized.conv_transpose1d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)
10251025

@@ -1227,7 +1227,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
12271227

12281228
out = F.conv_transpose2d(input, weight, bias, self.stride,
12291229
self.padding, self.output_padding,
1230-
self.dilation, self.groups)
1230+
self.groups, self.dilation)
12311231

12321232
else:
12331233
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
@@ -1250,7 +1250,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
12501250

12511251
self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(weight, bias, self.stride,
12521252
self.padding, self.output_padding,
1253-
self.dilation, self.groups)
1253+
self.groups, self.dilation)
12541254

12551255
out = torch.ops.quantized.conv_transpose2d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)
12561256

@@ -1458,7 +1458,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
14581458

14591459
out = F.conv_transpose3d(input, weight, bias, self.stride,
14601460
self.padding, self.output_padding,
1461-
self.dilation, self.groups)
1461+
self.groups, self.dilation)
14621462

14631463
else:
14641464
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
@@ -1481,7 +1481,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s
14811481

14821482
self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(weight, bias, self.stride,
14831483
self.padding, self.output_padding,
1484-
self.dilation, self.groups)
1484+
self.groups, self.dilation)
14851485

14861486
out = torch.ops.quantized.conv_transpose3d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)
14871487

0 commit comments

Comments
 (0)