 import torch.nn as nn
 import torch.nn.functional as F
 from torch.fx import GraphModule, Node
+from torch.quantization.observer import ObserverBase
+
 
-from mqbench.observer import MinMaxObserver, ObserverBase
 from mqbench.utils import deepcopy_graphmodule
+from mqbench.utils.state import enable_quantization, disable_all
+from mqbench.utils.logger import logger
+
 
-_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear, )
+__all__ = ['adaround']
+_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear)
 
 
 def lp_norm(prediction, target, p=2.0):
@@ -26,6 +31,7 @@ def lp_norm(prediction, target, p=2.0):
     """
     return (prediction - target).abs().pow(p).sum(1).mean()
 
+
 def _rectified_sigmoid(x, zeta, gamma):
     """Function to generate rounding mask.
 
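For readers skimming the diff, the two helpers shown in these hunks compute, for a (batch, features) tensor, the layer-wise reconstruction loss and the rectified-sigmoid rounding mask used by AdaRound (notation mine, matching the return statements visible in this file):

    \mathcal{L}_p(\hat{y}, y) = \frac{1}{N} \sum_i \sum_j \lvert \hat{y}_{ij} - y_{ij} \rvert^{p}
    h(\alpha) = \mathrm{clip}\big(\sigma(\alpha)\,(\zeta - \gamma) + \gamma,\; 0,\; 1\big)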
@@ -39,60 +45,28 @@ def _rectified_sigmoid(x, zeta, gamma):
     """
     return ((zeta - gamma) * torch.sigmoid(x) + gamma).clamp(0, 1)
 
-def get_cali_samples(train_data_loader, num_samples, no_label=True):
-    """Generate sub-dataset for calibration.
-
-    Args:
-        train_data_loader (torch.utils.data.DataLoader):
-        num_samples (int):
-        no_label (bool, optional): If the dataloader has no labels. Defaults to True.
 
-    Returns:
-        torch.Tensor: Concatenated data matrix.
-    """
-    cali_data_list = []
-    if no_label:
-        for batch_data in train_data_loader:
-            cali_data_list.append(batch_data["image"])
-            if len(cali_data_list) >= num_samples:
-                break
-    else:
-        for batch_data, _ in train_data_loader:
-            cali_data_list.append(batch_data)
-            if len(cali_data_list) >= num_samples:
-                break
-    return torch.cat(cali_data_list, dim=0)[:num_samples].cpu()
-
-def adaround(model: GraphModule, train_data, n_samples: int = 128,
-             lr: float = 4e-3, batch_size: int = 128, max_iter: int = 8000,
-             weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1,
-             quant_min: int = -128, quant_max: int = 127, per_channel: bool = False):
+def adaround(model: GraphModule, cali_data,
+             lr: float = 0.001, batch_size: int = 128, max_iter: int = 8000,
+             weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1):
     """Main function to run AdaRound on a given model.
 
     Args:
-        model (GraphModule):
-        train_data (torch.utils.data.DataLoader):
-        n_samples (int, optional): Defaults to 128.
-        lr (float, optional): Defaults to 4e-3.
+        model (GraphModule): Model to adaround.
+        cali_data (torch.tensor): Stacked tensor.
+        lr (float, optional): Defaults to 0.001.
         batch_size (int, optional): Defaults to 128.
         max_iter (int, optional): Defaults to 8000.
         weight (float, optional): Defaults to 0.01.
         beta (float, optional): Defaults to 20.
         gamma (float, optional): Defaults to -0.1.
         zeta (float, optional): Defaults to 1.1.
-        quant_min (int, optional): Defaults to -128.
-        quant_max (int, optional): Defaults to 127.
-        per_channel (bool, optional): Defaults to False.
 
     Returns:
         GraphModule: Modified copy of the given model.
     """
-    model.cpu()
-    print("AdaRound: Quant-Range="
-          "[{}, {}], Per-Channel={}".format(quant_min, quant_max, per_channel))
-
-    # sample data from training data
-    cali_data = get_cali_samples(train_data, n_samples)
+    device = cali_data.device
+    model.to(device)
 
     # apply rewritten deepcopy of GraphModule
     quant_model = deepcopy_graphmodule(model)
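For orientation on the new signature, a call site might look like the sketch below. `prepared_model` (an FX GraphModule whose target layers already carry `weight_fake_quant` modules) and `batches` are assumptions for illustration, not part of this commit.

    import torch

    # prepared_model: fake-quantized GraphModule (assumed); batches: list of calibration inputs (assumed)
    cali_data = torch.cat(batches, dim=0).cuda()   # the stacked tensor the docstring asks for;
                                                   # its device drives the whole AdaRound run
    rounded_model = adaround(prepared_model, cali_data, lr=0.001, batch_size=128, max_iter=8000)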
@@ -103,50 +77,33 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
     fp_observer_binding_dict = _insert_observer(model, "output")
     quant_observer_binding_dict = _insert_observer(quant_model, "input")
 
-    print("Record Outputs (by CPU) ...")
+    logger.info("Record Outputs ...")
     # apply data to record output
+    disable_all(model)
+    enable_quantization(quant_model)
+
     saver = FpOutputSaver(model, observer_binding_dict=fp_observer_binding_dict,
                           input_data=cali_data)
 
     # get layers for reconstruction
     modules = dict(quant_model.named_modules())
     quant_module_name_list = _get_quant_modules_by_topology(quant_model)
 
-    # TODO: more observer types / affine mode
-    if per_channel:
-        qscheme = torch.per_channel_symmetric
-        ch_axis = 0
-    else:
-        qscheme = torch.per_tensor_symmetric
-        ch_axis = -1
-
-    observer_type = MinMaxObserver.with_args(dtype=torch.qint8, quant_min=quant_min, quant_max=quant_max,
-                                             reduce_range=False, qscheme=qscheme, ch_axis=ch_axis)
-
-    scale_dict = _init_weight_scale(quant_model, quant_observer_binding_dict.keys(), observer_type)
-
     # disable gradient for all parameters
-    for n, m in quant_model.named_modules():
-        if hasattr(m, "weight"):
-            m.weight.requires_grad = False
-        if hasattr(m, "bias") and getattr(m, "bias") is not None:
-            m.bias.requires_grad = False
-
-    quant_model.cuda()
-    cali_data = cali_data.cuda()
+    for p in quant_model.parameters():
+        p.requires_grad = False
 
     # learn the rounding mask for each layer
     for node_name in quant_module_name_list:
-        print("===> Train for Layer: {}".format(node_name))
+        logger.info("Adaround for Layer: {}".format(node_name))
         # get input and output tensors
-        output_tensor = saver.get_result_by_name(node_name).cuda()
+        output_tensor = saver.get_result_by_name(node_name).to(device)
         input_observer = modules[quant_observer_binding_dict[node_name].name]
         cur_node = _get_node_by_name(quant_model, node_name)
         if cur_node is not None:
             module = modules[cur_node.target]
         else:
             raise RuntimeError("Node not found in graph.")
-        module.eval()
 
         with _Recorder(input_observer):
             with torch.no_grad():
@@ -158,12 +115,14 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
         ada_reg_loss = AdaRoundReg(zeta=zeta, gamma=gamma, weight=weight,
                                    temp_anneal=temp_anneal, h_func=_rectified_sigmoid)
 
-        scale, zero_point = scale_dict[node_name]
-        ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, ch_axis=ch_axis,
-                                          scale=scale, zero_point=zero_point,
-                                          quant_min=quant_min, quant_max=quant_max)
+        weight_fake_quant = module.weight_fake_quant
+        ch_axis = weight_fake_quant.activation_post_process.ch_axis
+        scale, zero_point = weight_fake_quant.activation_post_process.calculate_qparams()
+        quant_min, quant_max = weight_fake_quant.activation_post_process._calculate_qmin_qmax()
+        ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, scale=scale, zero_point=zero_point,
+                                          quant_min=quant_min, quant_max=quant_max, ch_axis=ch_axis)
 
-        ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).cuda()
+        ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).to(device)
 
         alpha = learning_alpha(input_tensor, output_tensor,
                                ada_layer, ada_reg_loss, lr,
@@ -173,9 +132,26 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
         module.weight.data = ada_quantizer(module.weight, alpha)
         module.weight.requires_grad = False
 
+    _del_tensor_observer(quant_model, quant_observer_binding_dict)
+
     return quant_model
 
 
+def _del_tensor_observer(gm: GraphModule, observer_binding_dict):
+    modules = dict(gm.named_modules())
+    nodes = list(gm.graph.nodes)
+    # Quant model tensor observer insert in 'input' mode.
+    for node in observer_binding_dict.values():
+        delattr(gm, node.name)
+        for _node in list(node.users.keys()):
+            _node.args = node.args
+    for node in observer_binding_dict.values():
+        gm.graph.erase_node(node)
+
+    gm.recompile()
+    gm.graph.lint()
+
+
 def _insert_observer(gm: GraphModule, insert_type="input"):
     """Insert observers to record the input and output of target layers.
 
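The body of `_insert_observer` is outside this hunk. As a hedged sketch of the insertion that `_del_tensor_observer` above undoes (delattr plus `erase_node`), recording a layer's output in torch.fx usually amounts to registering an observer submodule and adding a `call_module` node after the target; names and the observer choice here are illustrative only:

    from torch.quantization.observer import MinMaxObserver

    def _insert_output_observer_sketch(gm, target_node):
        obs_name = target_node.name + "_observer"      # hypothetical naming scheme
        setattr(gm, obs_name, MinMaxObserver())        # register the observer as a submodule
        with gm.graph.inserting_after(target_node):
            # feed the target's output through the observer; its statistics are recorded as a side effect
            gm.graph.create_node("call_module", obs_name, (target_node,), {})
        gm.recompile()
        return gm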
@@ -260,7 +236,7 @@ class FpOutputSaver:
     @torch.no_grad()
     def __init__(self, fp_gm: GraphModule,
                  observer_binding_dict: Dict[str, Node],
-                 save_loc="disk", root="./calibration",
+                 save_loc="disk", root="./cali_data_cache",
                  input_data=None):
         """
         Currently, there are two options provided to save floating point model
@@ -283,8 +259,8 @@ def __init__(self, fp_gm: GraphModule,
         self._data = dict()
 
         if self.save_loc == "disk" and not os.path.exists(self.data_root):
-            raise NotADirectoryError("The given path is not a folder."
-                                     "Ensure you give the correct path.")
+            logger.info('Save data on disk, create directory {}'.format(self.data_root))
+            os.mkdir(self.data_root)
         saving_operation = self._disk_saving_operation \
             if self.save_loc == "disk" else self._gpu_saving_operation
 
@@ -352,30 +328,6 @@ def _get_quant_modules_by_topology(gm: GraphModule):
             module_name_list.append(node.name)
     return module_name_list
 
-def _init_weight_scale(gm: GraphModule, observed_module_list, observer_type: Callable):
-    """Simulate the fake quant modules to calculate scales and zero-points.
-
-    Args:
-        gm (GraphModule):
-        observed_module_list (list):
-        observer_type (Callable):
-
-    Returns:
-        dict:
-    """
-    scale_dict = dict()
-    modules = dict(gm.named_modules())
-
-    for name in observed_module_list:
-        node = _get_node_by_name(gm, name)
-        if node.op == "call_module":
-            observer = observer_type()
-            module = modules[node.target]
-            weight = module.weight
-            observer(weight)
-            scale, zero_point = observer.calculate_qparams()
-            scale_dict[name] = (scale.cuda().detach(), zero_point.cuda().detach())
-    return scale_dict
 
 def _get_node_by_name(gm: GraphModule, node_name: str):
     """
@@ -446,8 +398,8 @@ def __call__(self, t):
 
 
 class AdaRoundQuantizer:
-    def __init__(self, reg: AdaRoundReg, ch_axis: int,
-                 scale, zero_point, quant_min=-128, quant_max=127,
+    def __init__(self, reg: AdaRoundReg, scale, zero_point,
+                 quant_min=-128, quant_max=127, ch_axis=-1,
                  soft=True):
         self.quant_min = quant_min
         self.quant_max = quant_max
@@ -465,11 +417,6 @@ def __init__(self, reg: AdaRoundReg, ch_axis: int,
     def __call__(self, w, alpha):
         scale = self.scale
         zero_point = self.zero_point
-        if self.ch_axis != -1:
-            new_shape = [1] * len(w.shape)
-            new_shape[self.ch_axis] = w.shape[self.ch_axis]
-            scale = self.scale.reshape(new_shape)
-            zero_point = self.zero_point.reshape(new_shape)
 
         if self.soft_quantize:
             w = (w / scale).floor() + self.h_func(alpha, self.zeta, self.gamma)
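Only the first line of the rounding arithmetic is visible in this hunk; a self-contained sketch of the full soft/hard fake-quant step is below. The clamp-around-`zero_point` part is inferred from the removed `round_to_nearset_quant` helper near the end of this diff, not from lines shown here.

    import torch

    def adaround_fake_quant(w, alpha, scale, zero_point, quant_min, quant_max,
                            zeta=1.1, gamma=-0.1, soft=True):
        h = ((zeta - gamma) * torch.sigmoid(alpha) + gamma).clamp(0, 1)  # rounding mask in [0, 1]
        w = (w / scale).floor() + (h if soft else (h >= 0.5).float())    # learned round-up / round-down
        w = (w + zero_point).clamp(quant_min, quant_max) - zero_point    # clamp to the integer range
        return w * scale                                                 # dequantize back to float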
@@ -483,15 +430,6 @@ def __call__(self, w, alpha):
         w = w * scale
         return w
 
-    def __repr__(self):
-        scale = self.scale.item()
-        if self.ch_axis != -1:
-            scale = "per-channel scale of " + str(tuple(self.scale.shape))
-        repr_str = "AdaRoundQuantizer(quant_min={}, quant_max={}, scale={}, " \
-                   "gamma={}, zeta={}, soft_quantize={})".format(self.quant_min, self.quant_max, scale,
-                                                                 self.gamma, self.zeta, self.soft_quantize)
-        return repr_str
-
 
 class AdaRoundLayer(nn.Module):
     def __init__(self, module: nn.Module,
@@ -506,16 +444,17 @@ def __init__(self, module: nn.Module,
         if self.module.bias is not None:
             self.module.bias.requires_grad = False
 
-        scale = self.quantizer.scale
         if self.quantizer.ch_axis != -1:
             new_shape = [1] * len(self.module.weight.shape)
             new_shape[self.quantizer.ch_axis] = self.module.weight.shape[self.quantizer.ch_axis]
-            scale = self.quantizer.scale.reshape(new_shape)
+            self.quantizer.scale = self.quantizer.scale.reshape(new_shape)
+            self.quantizer.zero_point = self.quantizer.zero_point.reshape(new_shape)
 
+        # Init rest.
+        scale = self.quantizer.scale
         rest = self.module.weight / scale - (self.module.weight / scale).floor()
         rest = -torch.log((reg.zeta - reg.gamma) / (rest - reg.gamma) - 1)
-
-        self.alpha = torch.nn.Parameter(rest.cuda(), True)
+        self.alpha = torch.nn.Parameter(rest, True)
 
     def forward(self, x):
         weight = self.quantizer(self.module.weight, self.alpha)
@@ -529,6 +468,10 @@ def forward(self, x):
         else:
             raise RuntimeError("Unsupported module type.")
 
+        if isinstance(self.module, (torch.nn.intrinsic.qat.ConvReLU2d,
+                                    torch.nn.intrinsic.qat.LinearReLU)):
+            x = F.relu(x)
+
         return x
 
 
@@ -541,7 +484,7 @@ def learning_alpha(in_tensor: torch.Tensor,
                    batch_size: int,
                    max_iter: int) -> torch.Tensor:
 
-    optimizer = torch.optim.Adam([ada_layer.alpha], lr=learning_rate)
+    optimizer = torch.optim.Adam([ada_layer.alpha])
 
     for epoch in range(max_iter):
         for idx in range(np.ceil(len(in_tensor) / batch_size).astype(int)):
@@ -560,33 +503,13 @@ def learning_alpha(in_tensor: torch.Tensor,
             loss.backward()
             optimizer.step()
 
-        if epoch % 200 == 0:
-            print("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: "
-                  "{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p,
-                                                                      loss_reg, ada_reg.beta))
+        if epoch % 100 == 0:
+            logger.info("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: "
+                        "{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p,
+                                                                            loss_reg, ada_reg.beta))
     res = ada_reg.round_mask(ada_layer.alpha)
-    print("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format(
+    logger.info("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format(
         loss,
         res[res + 1e-4 >= 1.0].numel(), res[res <= 1e-4].numel(), torch.numel(res),
         (res[res + 1e-4 >= 1.0].numel() + res[res <= 1e-4].numel()) / torch.numel(res)))
-    return ada_layer.alpha
-
-@torch.no_grad()
-def round_to_nearset_quant(m: nn.Module, scale, zero_point, quant_min, quant_max, ch_axis):
-    w = m.weight
-    if ch_axis != -1:
-        new_shape = [1] * len(w.shape)
-        new_shape[ch_axis] = w.shape[ch_axis]
-        scale = scale.reshape(new_shape)
-        zero_point = zero_point.reshape(new_shape)
-
-    w = (w / scale).round()
-    w += zero_point
-    w = w.clamp(quant_min, quant_max)
-    w -= zero_point
-    w = w * scale
-
-    return w
-
-if __name__ == "__main__":
-    pass
+    return ada_layer.alpha
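The logged terms combine the reconstruction error (`loss_p`, the `lp_norm` defined at the top of the file) with a rounding regularizer weighted by `weight` and sharpened by the annealed `beta`. `AdaRoundReg` itself is not shown in this diff, so the snippet below only restates the standard AdaRound regularizer those names suggest, as an illustration rather than the repository's exact implementation:

    import torch

    def rounding_reg(alpha, beta, weight=0.01, zeta=1.1, gamma=-0.1):
        # h(alpha): the rectified-sigmoid rounding mask from _rectified_sigmoid above
        h = ((zeta - gamma) * torch.sigmoid(alpha) + gamma).clamp(0, 1)
        # drives every mask entry towards 0 or 1 as beta anneals downward
        return weight * (1 - (2 * h - 1).abs().pow(beta)).sum()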