
Commit 69dc4db

pre-sampling for flipout layers
1 parent 18b296f commit 69dc4db

File tree

2 files changed: +47 -5 lines changed

bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py

Lines changed: 22 additions & 2 deletions
@@ -38,6 +38,7 @@
 from torch.nn import Parameter
 from ..base_variational_layer import BaseVariationalLayer_
 from .conv_flipout import *
+import random
 
 from torch.distributions.normal import Normal
 from torch.distributions.uniform import Uniform
@@ -285,6 +286,9 @@ def __init__(self,
 
         self.is_dequant = False
         self.quant_dict = None
+        self.presampled_input_perturb = None
+        self.presampled_output_perturb = None
+
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -442,8 +446,24 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
             self.dilation, self.groups, scale=self.quant_dict[3]['scale'], zero_point=self.quant_dict[3]['zero_point']) # input: quint8, weight: qint8, bias: fp32
 
         # sampling perturbation signs
-        sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
-        sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
+        input_tsize = torch.prod(torch.tensor(x.shape))*1
+        output_tsize = torch.prod(torch.tensor(outputs.shape))*1
+
+        if self.presampled_input_perturb is None:
+            self.presampled_input_perturb = torch.randint(0, 1, (input_tsize + torch.prod(torch.tensor(x.shape)),)).float()
+            self.presampled_input_perturb[self.presampled_input_perturb==0] = -1
+
+        if self.presampled_output_perturb is None:
+            self.presampled_output_perturb = torch.randint(0, 1, (output_tsize + torch.prod(torch.tensor(outputs.shape)),)).float()
+            self.presampled_output_perturb[self.presampled_output_perturb==0] = -1
+
+        st = random.randint(0, input_tsize)
+        sign_input = self.presampled_input_perturb[st:st+torch.prod(torch.tensor(x.shape))].reshape(x.shape)
+
+        st = random.randint(0, output_tsize)
+        sign_output = self.presampled_output_perturb[st:st+torch.prod(torch.tensor(outputs.shape))].reshape(outputs.shape)
+        # sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
+        # sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
         sign_input = torch.quantize_per_tensor(sign_input, self.quant_dict[4]['scale'], self.quant_dict[4]['zero_point'], torch.quint8)
         sign_output = torch.quantize_per_tensor(sign_output, self.quant_dict[5]['scale'], self.quant_dict[5]['zero_point'], torch.quint8)
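The hunk above replaces the per-call torch.zeros(...).uniform_(-1, 1).sign() draws with a pool of pre-sampled sign values that is filled once and then sliced at a random offset on every forward pass. Below is a minimal standalone sketch of that idea; the class name PresampledSigns is illustrative (not from the commit), and it fills the pool with torch.randint(0, 2, ...) so that both +1 and -1 values actually occur (torch.randint's upper bound is exclusive).

import random
import torch

class PresampledSigns:
    """Pre-sample a flat pool of +/-1 values once, then serve random slices of it.

    Replaces a fresh uniform_(-1, 1).sign() draw per forward pass with a cheap
    slice into a cached pool, at the cost of keeping roughly twice one tensor's
    worth of values in memory.
    """

    def __init__(self):
        self.pool = None  # lazily created on the first call

    def sample(self, shape):
        numel = int(torch.prod(torch.tensor(shape)))
        if self.pool is None or self.pool.numel() < 2 * numel:
            # Pool twice the tensor size, so any start offset in [0, numel]
            # still leaves a full window of numel values.
            pool = torch.randint(0, 2, (2 * numel,)).float()
            pool[pool == 0] = -1       # map {0, 1} -> {-1, +1}
            self.pool = pool
        st = random.randint(0, numel)  # random start offset into the pool
        return self.pool[st:st + numel].reshape(shape)

signs = PresampledSigns()
sign_input = signs.sample((2, 3, 8, 8))    # +/-1 tensor shaped like the input
sign_output = signs.sample((2, 16, 8, 8))  # +/-1 tensor shaped like the output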

bayesian_torch/layers/flipout_layers/quantized_linear_flipout.py

Lines changed: 25 additions & 3 deletions
@@ -39,6 +39,7 @@
 from torch.nn import Module, Parameter
 from torch.distributions.normal import Normal
 from torch.distributions.uniform import Uniform
+import random
 
 from .linear_flipout import LinearFlipout

@@ -55,6 +56,8 @@ def __init__(self,
 
         self.is_dequant = False
         self.quant_dict = None
+        self.presampled_input_perturb = None
+        self.presampled_output_perturb = None
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -120,7 +123,7 @@ def quantize(self):
         delattr(self, "rho_weight")
 
         self.quantized_mu_bias = self.mu_bias#Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False)
-        self.quantized_sigma_bias = torch.log1p(torch.exp(self.rho_bias))#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
+        self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False)#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
         delattr(self, "mu_bias")
         delattr(self, "rho_bias")
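The one-line change above stores the de-quantized bias sigma as a Parameter with requires_grad=False rather than as a plain tensor. A small sketch of what that registration buys (the Demo module and its attribute names are hypothetical): a Parameter assigned to a module attribute appears in state_dict() and moves with the module on .to(...), while a plain tensor attribute does not.

import torch
from torch.nn import Module, Parameter

class Demo(Module):
    def __init__(self):
        super().__init__()
        rho_bias = torch.zeros(4)
        # Plain tensor attribute: not registered with the module.
        self.sigma_plain = torch.log1p(torch.exp(rho_bias))
        # Parameter with requires_grad=False: registered (saved, loaded,
        # moved with the module) but excluded from gradient updates.
        self.sigma_param = Parameter(torch.log1p(torch.exp(rho_bias)),
                                     requires_grad=False)

m = Demo()
print(list(m.state_dict().keys()))  # ['sigma_param'] -- sigma_plain is absent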

@@ -191,8 +194,27 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
         outputs = torch.nn.quantized.functional.linear(x, self.quantized_mu_weight, bias, scale=self.quant_dict[3]['scale'], zero_point=self.quant_dict[3]['zero_point']) # input: quint8, weight: qint8, bias: fp32
 
         # sampling perturbation signs
-        sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
-        sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
+        # sampling perturbation signs
+        input_tsize = torch.prod(torch.tensor(x.shape))*1
+        output_tsize = torch.prod(torch.tensor(outputs.shape))*1
+
+        if self.presampled_input_perturb is None:
+            self.presampled_input_perturb = torch.randint(0, 1, (input_tsize + torch.prod(torch.tensor(x.shape)),)).float()
+            self.presampled_input_perturb[self.presampled_input_perturb==0] = -1
+
+        if self.presampled_output_perturb is None:
+            self.presampled_output_perturb = torch.randint(0, 1, (output_tsize + torch.prod(torch.tensor(outputs.shape)),)).float()
+            self.presampled_output_perturb[self.presampled_output_perturb==0] = -1
+
+        st = random.randint(0, input_tsize)
+        sign_input = self.presampled_input_perturb[st:st+torch.prod(torch.tensor(x.shape))].reshape(x.shape)
+
+        st = random.randint(0, output_tsize)
+        sign_output = self.presampled_output_perturb[st:st+torch.prod(torch.tensor(outputs.shape))].reshape(outputs.shape)
+
+
+        # sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
+        # sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
         sign_input = torch.quantize_per_tensor(sign_input, self.quant_dict[4]['scale'], self.quant_dict[4]['zero_point'], torch.quint8)
         sign_output = torch.quantize_per_tensor(sign_output, self.quant_dict[5]['scale'], self.quant_dict[5]['zero_point'], torch.quint8)
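The linear layer applies the same pre-sampling scheme as the convolutional layer above. For a rough feel of the trade-off the commit message refers to, here is a self-contained timing sketch comparing a fresh sign draw per call against a slice of a cached pool; the shape, iteration count, and the torch.randint(0, 2, ...) draw are illustrative choices, not taken from the commit.

import random
import time
import torch

shape = (64, 512)                                 # example activation shape
numel = int(torch.prod(torch.tensor(shape)))

# One-time pool of +/-1 values, twice the size of a single tensor.
pool = torch.randint(0, 2, (2 * numel,)).float()
pool[pool == 0] = -1

def fresh_signs():
    # Per-call sampling, as in the lines the commit removes.
    return torch.zeros(shape).uniform_(-1, 1).sign()

def presampled_signs():
    # Per-call work is just picking a random window into the cached pool.
    st = random.randint(0, numel)
    return pool[st:st + numel].reshape(shape)

for fn in (fresh_signs, presampled_signs):
    start = time.perf_counter()
    for _ in range(1000):
        fn()
    print(fn.__name__, round(time.perf_counter() - start, 4), "s")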
