-
Notifications
You must be signed in to change notification settings - Fork 18
Error in Quantizing ResNet18 #20
Description
Hi,
I tried to quantize the standard ResNet-18 model from torchvision and got an error. The code snippet and the error message are below.
# Reproducer for this issue: quantize a torchvision ResNet-18 with NEMO (PACT),
# then call qd_stage() and integerize_pact(); qd_stage() raises the ValueError
# shown in the traceback.
import torch
import torch.nn as nn
import nemo
from torchvision.models import resnet18

# No pretrained weights are requested, so this is torchvision's default
# (randomly initialised) ResNet-18, moved to the GPU.
model = resnet18().cuda()

# Per-layer minimum precisions, keyed by the module names from named_modules():
# 4-bit weights for Conv2d and BatchNorm2d, 4-bit activations for ReLU.
precision_dict = {}
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Conv2d):
        precision_dict[name] = {"W_bits": 4}
    elif isinstance(mod, torch.nn.ReLU):
        precision_dict[name] = {"x_bits": 4}
    elif isinstance(mod, torch.nn.BatchNorm2d):
        precision_dict[name] = {"W_bits": 4}

# Convert with NEMO's quantize_pact, given a 1x3x224x224 dummy input; .to(0)
# puts it on CUDA device 0 (assumes that is the device .cuda() used above).
model = nemo.transform.quantize_pact(model,
    dummy_input=torch.randn((1,3,224,224)).to(0))
# Presumably a 1-bit default that precision_dict raises per layer — confirm
# against the NEMO docs.
model.change_precision(bits=1, min_prec_dict=precision_dict)
model.reset_alpha_weights()
# NOTE(review): no activation statistics are gathered before this reset
# (NEMO's examples appear to run a forward pass inside
# `with model.statistics_act(): ...` first). If the reset alphas come from
# those statistics, they would be 0 here. That could make eps_kappa_lim and
# kappa_max 0 and give the 0/0 NaN inside torch.log2(...) seen in
# _calibrate_bn_pact — TODO confirm against the NEMO source.
model.reset_alpha_act()
# Left disabled by the reporter.
# model.fold_bn()

# Single forward pass on an all-ones input; output_ is not used afterwards.
input_ = torch.ones(1, 3, 224, 224).cuda()
output_ = model(input_)

# The ValueError is raised inside this call
# (qd_stage -> calibrate_bn -> _calibrate_bn_pact in the traceback).
# eps_in=1/255 presumably matches 8-bit images scaled to [0, 1] — confirm.
model.qd_stage(eps_in=1./255, add_input_bias_dict=None,
    remove_bias_dict=None,
    precision=nemo.precision.Precision(bits=8), int_accurate=True,
    limit_at_32_bits=False, postpone_bn_hardening=False)
model = nemo.transform.integerize_pact(model, eps_in=1.0/255)
ValueError Traceback (most recent call last)
in
27 remove_bias_dict=None,
28 precision=nemo.precision.Precision(bits=8), int_accurate=True,
---> 29 limit_at_32_bits=False, postpone_bn_hardening=False)
30
31 model = nemo.transform.integerize_pact(model, eps_in=1.0/255)
nemo/nemo/transf/deploy.py in _qd_stage(self, eps_in, add_input_bias_dict, remove_bias_dict, prune_empty_bn, int_accurate, bn_calibration_fn, bn_calibration_range_factor, **kwargs)
165 self.calibrate_bn(**kwargs)
166 else:
--> 167 self.calibrate_bn(minmax=False, range_factor=bn_calibration_range_factor, **kwargs)
168 self.set_deployment(eps_in=eps_in, **kwargs) # repeat, to fix BN eps
169 self.harden_weights()
nemo/nemo/transf/bn.py in _calibrate_bn_pact(self, calib_dict, kappa_bit_default, lamda_bit_default, kappa_dict, lamda_dict, range_factor, minmax, **kwargs)
119 eps_lim = max(out_range, lamda_max) / (2**(lamda_bit_lim - 1) - 1)
120 eps_kappa_lim = eps_lim / m.eps_in
--> 121 kappa_bits = min(int(min(torch.log2(1 + 2*kappa_max / eps_kappa_lim).floor(), kappa_bit_lim)), BN_PRECISION_MAX)
122 lamda_bits = min(lamda_bit_lim, BN_PRECISION_MAX)
123 m.precision_kappa.set_bits(kappa_bits)
ValueError: cannot convert float NaN to integer