Skip to content

Commit e2f6d78

Browse files
Tracin and zhangqi3
authored
[Observer] Update state to enable observer while training. (#227)
* [Observer] Update state to enable observer while training. Co-authored-by: zhangqi3 <[email protected]>
1 parent b5fb351 commit e2f6d78

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

mqbench/utils/state.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,32 @@ def enable_calibration_quantization(model, quantizer_type='fake_quant'):
3838
submodule.enable_fake_quant()
3939

4040

41-
def enable_quantization(model):
41+
def enable_quantization(model, weight_cali_on=False, act_cali_on=False):
42+
'''
43+
We enable all quantization for quantization aware training.
44+
But we sometimes remain weight calibration on for update minmax all along.
45+
For some hardware, there is no weight quant param to be set, which mean it will calculate
46+
min / max for weight.
47+
Assume weight scale * 127 > abs(weight).max() after some training. Training scale and deploy
48+
scale can be various, so we have to update range every iter.
49+
'''
4250
logger.info('Disable observer and Enable quantize.')
51+
if weight_cali_on:
52+
logger.info('Enable observer for weight.')
53+
if act_cali_on:
54+
logger.info('Enable observer for activation.')
4355
for name, submodule in model.named_modules():
4456
if isinstance(submodule, torch.quantization.FakeQuantizeBase):
45-
logger.debug('Disable observer and Enable quant: {}'.format(name))
46-
submodule.disable_observer()
4757
submodule.enable_fake_quant()
58+
if weight_cali_on and 'weight_fake_quant' in name:
59+
logger.debug('Enable observer and Enable quant: {}'.format(name))
60+
submodule.enable_observer()
61+
elif act_cali_on and 'act_fake_quant' in name:
62+
logger.debug('Enable observer and Enable quant: {}'.format(name))
63+
submodule.enable_observer()
64+
else:
65+
logger.debug('Disable observer and Enable quant: {}'.format(name))
66+
submodule.disable_observer()
4867

4968

5069
def disable_all(model):

0 commit comments

Comments
 (0)