@@ -38,13 +38,32 @@ def enable_calibration_quantization(model, quantizer_type='fake_quant'):
         submodule.enable_fake_quant()
 
 
-def enable_quantization(model):
+def enable_quantization(model, weight_cali_on=False, act_cali_on=False):
+    '''
+    Enable all quantization for quantization-aware training (QAT).
+    We sometimes keep weight calibration on so that the weight min/max keeps
+    updating throughout training: for some hardware there is no weight quant
+    param to set, and the hardware computes the weight min/max itself.
+    After some training, weight_scale * 127 may exceed abs(weight).max(); the
+    training scale and deploy scale can differ, so we update the range every iteration.
+    '''
     logger.info('Disable observer and Enable quantize.')
+    if weight_cali_on:
+        logger.info('Enable observer for weight.')
+    if act_cali_on:
+        logger.info('Enable observer for activation.')
     for name, submodule in model.named_modules():
         if isinstance(submodule, torch.quantization.FakeQuantizeBase):
-            logger.debug('Disable observer and Enable quant: {}'.format(name))
-            submodule.disable_observer()
             submodule.enable_fake_quant()
+            if weight_cali_on and 'weight_fake_quant' in name:
+                logger.debug('Enable observer and Enable quant: {}'.format(name))
+                submodule.enable_observer()
+            elif act_cali_on and 'act_fake_quant' in name:
+                logger.debug('Enable observer and Enable quant: {}'.format(name))
+                submodule.enable_observer()
+            else:
+                logger.debug('Disable observer and Enable quant: {}'.format(name))
+                submodule.disable_observer()
 
 
 def disable_all(model):
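
A minimal sketch of the observer semantics this change relies on, using a standalone torch.quantization.FakeQuantize module (the constructor arguments here are illustrative assumptions, not taken from this repo):

import torch
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver

# With the observer enabled, each forward pass refreshes min/max and hence the
# scale; with it disabled, the scale stays frozen at its last calibrated value.
fq = FakeQuantize(observer=MovingAverageMinMaxObserver,
                  quant_min=-128, quant_max=127, dtype=torch.qint8)
fq.enable_fake_quant()
fq.enable_observer()
fq(torch.randn(16))          # scale now tracks this tensor's range
fq(torch.randn(16) * 5.0)    # scale moves with the wider range
fq.disable_observer()
fq(torch.randn(16) * 50.0)   # scale stays frozen, as for a deployed weight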
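
And a hedged usage sketch of the new flag during QAT; `model`, `optimizer`, and `loader` are hypothetical placeholders for a model already prepared with fake-quant submodules named as the loop above expects ('weight_fake_quant' / 'act_fake_quant'):

import torch.nn.functional as F

model.train()
# Keep the weight observers running so the weight range follows the drifting
# weights, while activation ranges stay frozen from calibration.
enable_quantization(model, weight_cali_on=True, act_cali_on=False)
for inputs, targets in loader:
    loss = F.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()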