@@ -295,6 +295,15 @@ def __init__(self,
295
295
self .register_buffer ('prior_bias_sigma' , None , persistent = False )
296
296
297
297
self .init_parameters ()
298
+ self .quant_prepare = False
299
+
300
def prepare(self):
    """Attach quantization stubs so ``torch.quantization.prepare`` can insert
    observers for post-training static quantization of this layer.

    Creates five qint8 QuantStubs (weight-path tensors: sigma, mu, eps, the
    multiply result, and the summed weight), two quint8 QuantStubs (input and
    output activations), and a DeQuantStub, then marks the module ready via
    ``self.quant_prepare`` (checked in ``forward``).
    """
    # Bug fix: QConfig is a namedtuple requiring BOTH ``activation`` and
    # ``weight``; the original call omitted ``weight`` and raised TypeError.
    # QuantStub only consumes the activation observer, so ``weight=None``
    # is the correct filler.
    self.qint_quant = nn.ModuleList([
        torch.quantization.QuantStub(
            QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8),
                    weight=None))
        for _ in range(5)
    ])
    self.quint_quant = nn.ModuleList([
        torch.quantization.QuantStub(
            QConfig(activation=HistogramObserver.with_args(dtype=torch.quint8),
                    weight=None))
        for _ in range(2)
    ])
    self.dequant = torch.quantization.DeQuantStub()
    self.quant_prepare = True
298
307
299
308
def init_parameters (self ):
300
309
self .prior_weight_mu .fill_ (self .prior_mean )
@@ -325,7 +334,8 @@ def forward(self, input, return_kl=True):
325
334
326
335
sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
327
336
eps_kernel = self .eps_kernel .data .normal_ ()
328
- weight = self .mu_kernel + (sigma_weight * eps_kernel )
337
+ tmp_result = sigma_weight * eps_kernel
338
+ weight = mu_kernel + tmp_result
329
339
330
340
if return_kl :
331
341
kl_weight = self .kl_div (self .mu_kernel , sigma_weight ,
@@ -342,6 +352,20 @@ def forward(self, input, return_kl=True):
342
352
343
353
out = F .conv2d (input , weight , bias , self .stride , self .padding ,
344
354
self .dilation , self .groups )
355
+
356
+ if self .quant_prepare :
357
+ # quint8 quantstub
358
+ input = self .quint_quant [0 ](input ) # input
359
+ out = self .quint_quant [1 ](out ) # output
360
+
361
+ # qint8 quantstub
362
+ sigma_weight = self .qint_quant [0 ](sigma_weight ) # weight
363
+ mu_kernel = self .qint_quant [1 ](self .mu_kernel ) # weight
364
+ eps_kernel = self .qint_quant [2 ](eps_kernel ) # random variable
365
+ tmp_result = self .qint_quant [3 ](tmp_result ) # multiply activation
366
+ weight = self .qint_quant [4 ](weight ) # add activatation
367
+
368
+
345
369
if return_kl :
346
370
if self .bias :
347
371
kl = kl_weight + kl_bias
@@ -946,3 +970,12 @@ def forward(self, input, return_kl=True):
946
970
return out , kl
947
971
948
972
return out
973
+
974
if __name__ == "__main__":
    # Smoke-test the eager-mode post-training static quantization workflow
    # on a small Bayesian conv layer: prepare -> calibrate -> convert.
    layer = Conv2dReparameterization(3, 3, 3)
    layer.eval()
    layer.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    prepared = torch.quantization.prepare(layer)
    sample = torch.randn(3, 3, 4, 4)
    prepared(sample)  # one calibration pass to populate the observers
    quantized = torch.quantization.convert(prepared)
0 commit comments