@@ -301,9 +301,9 @@ def __init__(self,
 
     def prepare(self):
         self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
-            QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
+            QConfig(weight=HistogramObserver.with_args(dtype=torch.qint8), activation=HistogramObserver.with_args(dtype=torch.qint8))) for _ in range(5)])
         self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
-            QConfig(activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+            QConfig(weight=HistogramObserver.with_args(dtype=torch.quint8), activation=HistogramObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
         self.dequant = torch.quantization.DeQuantStub()
         self.quant_prepare = True
 
@@ -337,7 +337,7 @@ def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         tmp_result = sigma_weight * eps_kernel
-        weight = mu_kernel + tmp_result
+        weight = self.mu_kernel + tmp_result
 
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
@@ -976,6 +976,7 @@ def forward(self, input, return_kl=True):
 if __name__ == "__main__":
     m = Conv2dReparameterization(3, 3, 3)
     m.eval()
+    m.prepare()
     m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
     mp = torch.quantization.prepare(m)
     input = torch.randn(3, 3, 4, 4)
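
Note: with m.prepare() called before torch.quantization.prepare(m), the smoke test above can be taken through the rest of PyTorch's eager-mode static quantization flow. A minimal sketch, not part of this diff, assuming the fbgemm backend is available; it reuses the diff's own mp and input, and the variable name mq is illustrative:

    # continues the __main__ block above
    with torch.no_grad():
        mp(input)                            # calibration pass: observers record activation ranges
    mq = torch.quantization.convert(mp)      # replace observed modules with quantized ones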