@@ -39,6 +39,7 @@
 from torch.nn import Module, Parameter
 from torch.distributions.normal import Normal
 from torch.distributions.uniform import Uniform
+import random
 
 from .linear_flipout import LinearFlipout
 
@@ -55,6 +56,8 @@ def __init__(self,
 
         self.is_dequant = False
         self.quant_dict = None
+        self.presampled_input_perturb = None
+        self.presampled_output_perturb = None
 
     def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255):
         """ An implementation for symmetric quantization
@@ -120,7 +123,7 @@ def quantize(self):
         delattr(self, "rho_weight")
 
         self.quantized_mu_bias = self.mu_bias  # Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False)
-        self.quantized_sigma_bias = torch.log1p(torch.exp(self.rho_bias))  # Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
+        self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False)  # Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False)
         delattr(self, "mu_bias")
         delattr(self, "rho_bias")
 
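The sigma tensor wrapped here is just softplus(rho), which keeps the learned scale positive. A minimal standalone check of that identity, assuming only stock PyTorch:

```python
import torch
import torch.nn.functional as F

rho = torch.tensor([-3.0, 0.0, 2.0])
sigma = torch.log1p(torch.exp(rho))  # softplus(rho): smooth and always positive
assert torch.allclose(sigma, F.softplus(rho))
```

Wrapping the result in `Parameter(..., requires_grad=False)` registers it on the module so it appears in `state_dict()` without tracking gradients; presumably that is the motivation for the change.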
@@ -191,8 +194,24 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1
             outputs = torch.nn.quantized.functional.linear(x, self.quantized_mu_weight, bias, scale=self.quant_dict[3]['scale'], zero_point=self.quant_dict[3]['zero_point'])  # input: quint8, weight: qint8, bias: fp32
 
             # sampling perturbation signs
-            sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign()
-            sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign()
+            input_tsize = int(torch.prod(torch.tensor(x.shape)))
+            output_tsize = int(torch.prod(torch.tensor(outputs.shape)))
+
+            if self.presampled_input_perturb is None:
+                # presample a 2x-length buffer of random {0, 1} draws, then map 0 -> -1
+                self.presampled_input_perturb = torch.randint(0, 2, (2 * input_tsize,)).float()
+                self.presampled_input_perturb[self.presampled_input_perturb == 0] = -1
+
+            if self.presampled_output_perturb is None:
+                self.presampled_output_perturb = torch.randint(0, 2, (2 * output_tsize,)).float()
+                self.presampled_output_perturb[self.presampled_output_perturb == 0] = -1
+
+            # draw signs as a random-offset slice of the presampled buffer
+            st = random.randint(0, input_tsize)
+            sign_input = self.presampled_input_perturb[st:st + input_tsize].reshape(x.shape)
+
+            st = random.randint(0, output_tsize)
+            sign_output = self.presampled_output_perturb[st:st + output_tsize].reshape(outputs.shape)
 
             sign_input = torch.quantize_per_tensor(sign_input, self.quant_dict[4]['scale'], self.quant_dict[4]['zero_point'], torch.quint8)
             sign_output = torch.quantize_per_tensor(sign_output, self.quant_dict[5]['scale'], self.quant_dict[5]['zero_point'], torch.quint8)
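The new path trades a fresh RNG call per forward pass for random-offset slices of a presampled ±1 buffer, which is cheaper inside a quantized forward pass. A minimal standalone sketch of the same trick (the helper names here are illustrative, not part of the library):

```python
import random
import torch

def make_sign_buffer(n: int) -> torch.Tensor:
    # Presample 2*n random +/-1 signs so every n-length slice stays in bounds.
    buf = torch.randint(0, 2, (2 * n,)).float()
    buf[buf == 0] = -1.0
    return buf

def draw_signs(buf: torch.Tensor, shape: tuple) -> torch.Tensor:
    # Cheap per-call draw: slice the buffer at a random offset and reshape.
    n = 1
    for d in shape:
        n *= d
    st = random.randint(0, buf.numel() - n)
    return buf[st:st + n].reshape(shape)

signs = draw_signs(make_sign_buffer(12), (3, 4))
assert signs.abs().eq(1.0).all()
```

Adjacent draws overlap, so signs are not independent across calls; for Flipout-style perturbations at inference time that is presumably an acceptable trade for avoiding a full-size `randint` on every call.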