This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 14553a0

Dominika Jedynak and Bartlomiej Gawrych authored
Added ALBERT v2 quantization with INC example (#1591)
* Add quantization to QA scripts
* fix
* Remove quantize bool field
* Fix electra large accuracy
* Update mkldnn to onednn
* Accuracy fix
* Add sphinx to dev requirements
* remove print
* change quantize_mode to proper one
* fix round_to argument
* Albert example

Co-authored-by: Bartlomiej Gawrych <[email protected]>
Co-authored-by: Bartlomiej Gawrych <[email protected]>
1 parent fecd3e1 commit 14553a0

File tree

6 files changed: +1103 -14 lines changed
Lines changed: 15 additions & 0 deletions (new file)

```yaml
version: 1.0

model:
  name: albert_base_v2
  framework: mxnet

tuning:
  strategy:
    name: mycustom
  accuracy_criterion:
    relative: 0.02
  exit_policy:
    timeout: 0
    max_trials: 1000
  random_seed: 9527
```
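For context, a minimal sketch of how a YAML like this drives the INC 1.x tuning loop from Python. The file name `albert.yaml` and the variables `net`, `calib_dataloader`, and `eval_func` are placeholders, not part of this commit:

```python
# Sketch only: assumes the INC 1.x experimental API used at the time of this commit.
from neural_compressor.experimental import Quantization, common

quantizer = Quantization('albert.yaml')        # picks up tuning.strategy.name: mycustom
quantizer.model = common.Model(net)            # FP32 MXNet model to quantize
quantizer.calib_dataloader = calib_dataloader  # feeds calib_iteration / calib_sampling_size
quantizer.eval_func = eval_func                # checked against accuracy_criterion (relative 0.02)
qmodel = quantizer.fit()                       # runs the strategy's next_tune_cfg() loop
```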
Lines changed: 176 additions & 0 deletions (new file)

```python
import copy
import numpy as np
from collections import OrderedDict
from neural_compressor.strategy.strategy import TuneStrategy, strategy_registry

plot_operator_influence = True


def calc_approx_error(expected_tensor: np.ndarray, observed_tensor: np.ndarray) -> float:
    '''
    Calculate the relative error for one tensor.
    '''
    error = observed_tensor - expected_tensor
    absolute_error = np.abs(error)
    mean_absolute_error = absolute_error.mean()
    mean_expected_value = np.abs(expected_tensor).mean()
    error = mean_absolute_error / mean_expected_value
    return error


def get_approx_errors(expected_tensors, observed_tensors):
    '''
    Calculate relative errors for multiple tensors: Dict[tensor_name: str, tensor: np.ndarray]
    '''
    errors = {}
    for node_name in observed_tensors.keys():
        expected_tensor = expected_tensors[node_name][node_name]
        observed_tensor = observed_tensors[node_name][node_name]
        errors[node_name] = calc_approx_error(expected_tensor, observed_tensor)
    return errors


@strategy_registry
class MyCustomTuneStrategy(TuneStrategy):
    '''INC custom strategy definition'''
    def __init__(self, model, conf, q_dataloader, q_func=None,
                 eval_dataloader=None, eval_func=None, dicts=None, q_hooks=None):
        super().__init__(
            model,
            conf,
            q_dataloader,
            q_func,
            eval_dataloader,
            eval_func,
            dicts,
            q_hooks)

    def get_qtensors(self, quant_cfg, node_list):
        '''
        Generate a quantized model from the given configuration and capture intermediate tensors.
        '''
        qmodel = self.adaptor.quantize(quant_cfg, self.model, self.calib_dataloader)
        tensors = self.adaptor.inspect_tensor(qmodel, self.calib_dataloader, node_list, [1])  # 1 is a batch index
        return tensors['activation'][0]  # 'activation' selects layer outputs (INC also stores
                                         # weight tensors); 0 is the first captured batch

    def next_tune_cfg(self):
        FALLBACK_DTYPE = 'fp32'

        # Create the base configuration: all nodes quantized and calibrated with the minmax algorithm
        best_cfg = {}
        best_cfg['calib_iteration'] = int(self.calib_iter[0])  # number of batches for calibration
        best_cfg['calib_sampling_size'] = int(self.calib_sampling_size[0])  # number of calibration samples (a multiple of the batch size)
        nodes_cfg = OrderedDict()
        nodes_cfg_idx = {}
        for node_key, cfgs in self.opwise_tune_cfgs.items():
            for i, cfg in enumerate(cfgs):
                if cfg['activation']['algorithm'] == 'minmax':
                    nodes_cfg_idx[node_key] = i
                    break
            nodes_cfg[node_key] = cfg
        best_cfg['op'] = nodes_cfg

        yield best_cfg

        # If the fully quantized model does not meet the requirements, proceed to exclude some nodes

        # Collect tensors from the original model - the expected tensors
        node_list = [op_name for (op_name, op_type) in best_cfg['op'].keys()]
        f32_tensors = self.adaptor.inspect_tensor(self.model, self.calib_dataloader, node_list, [1])
        f32_tensors = f32_tensors['activation'][0]

        # Collect tensors from the fully quantized model
        q_tensors = self.get_qtensors(best_cfg, node_list)
        approx_errors = get_approx_errors(f32_tensors, q_tensors)

        # best_cfg['op'] is an OrderedDict whose element order should correspond to the
        # order of the nodes in the computational graph
        for node_key, cfg in best_cfg['op'].items():
            # A node's key in INC is its name plus its operator
            node_name, node_op = node_key
            # Check which configuration options are available for this particular node
            capabilities = self.opwise_tune_space[node_key]['activation']['dtype']
            # If this node can be excluded from quantization ('fp32' in capabilities)
            # and its current error exceeds the threshold, check what accuracy
            # improvement the exclusion would achieve
            if FALLBACK_DTYPE in capabilities and approx_errors[node_name] > 0.06:
                original_dtype = cfg['activation']['dtype']
                cfg['activation']['dtype'] = FALLBACK_DTYPE  # Exclude the node from quantization

                # Collect tensors for the new configuration with the current node excluded
                q_tensors = self.get_qtensors(best_cfg, node_list)
                # Calculate errors for the new configuration
                new_approx_errors = get_approx_errors(f32_tensors, q_tensors)
                # Calculate the error difference for every node in the model
                err_diffs = {}
                for tensor_node_name in new_approx_errors.keys():
                    diff = approx_errors[tensor_node_name] - new_approx_errors[tensor_node_name]
                    err_diffs[tensor_node_name] = diff
                err_diffs_arr = np.array(list(err_diffs.values()))

                # If the summed error improvement over the following layers exceeds the
                # threshold value, keep the node excluded
                threshold_sum_error_layers = err_diffs_arr.size * 0.01
                if err_diffs_arr.sum() >= threshold_sum_error_layers:
                    before = approx_errors
                    after = approx_errors.copy()
                    after.update(new_approx_errors)
                    if plot_operator_influence:
                        import matplotlib.pyplot as plt
                        plt.figure()
                        plt.plot(before.values(), marker='o', markersize=2.5, label='Before')
                        plt.plot(after.values(), marker='o', markersize=2.5, label='After')
                        plt.ylabel('Relative error')
                        plt.xlabel('Layer')
                        plt.legend()
                        plt.savefig(f'{node_name}_error.png')

                    approx_errors.update(new_approx_errors)
                    nodes_cfg_idx.pop(node_key)  # Mark the node as not quantizable
                else:
                    cfg['activation']['dtype'] = original_dtype

        yield best_cfg

        # Choose the calibration algorithm (kl or minmax) for every node that was not excluded from quantization
        for cfg in self.bayesian_configurations(best_cfg, nodes_cfg_idx):
            yield cfg

    def bayesian_params_to_tune_configs(self, params):
        '''
        Create a configuration from params - replace configuration indexes with real configurations.
        '''
        node_cfgs = {}
        for node_key, configs in self.opwise_quant_cfgs.items():
            if node_key in params:
                value = int(params[node_key])
                value = min(value, len(configs) - 1)
                node_cfgs[node_key] = copy.deepcopy(configs[value])
        return node_cfgs

    def bayesian_configurations(self, cfg_base, params_base):
        from neural_compressor.strategy.bayesian import BayesianOptimization

        # For each node, specify the possible range of values (treated as a configuration index)
        pbounds = {}
        for node_key, configs in self.opwise_quant_cfgs.items():
            if node_key in params_base and len(configs) > 1:
                pbounds[node_key] = (0, len(configs))

        cfg = copy.deepcopy(cfg_base)
        if len(pbounds) == 0:  # if there is nothing to optimize, we finish
            cfg['op'].update(self.bayesian_params_to_tune_configs(params_base))
            return

        bayes_opt = BayesianOptimization(pbounds=pbounds, random_seed=self.cfg.tuning.random_seed)
        bayes_opt._space.register(params_base, self.last_tune_result[0])  # register the outcome of the current configuration
        while True:
            # Generate the next configuration
            params = bayes_opt.gen_next_params()
            cfg['op'].update(self.bayesian_params_to_tune_configs(params))
            yield cfg
            try:
                # Register the outcome
                bayes_opt._space.register(params, self.last_tune_result[0])
            except KeyError:
                pass
```
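As of INC 1.x, `@strategy_registry` registers a `TuneStrategy` subclass under its class name with the `TuneStrategy` suffix stripped and lower-cased, so `MyCustomTuneStrategy` is looked up as `mycustom`, the value set under `tuning.strategy.name` in the YAML above. The per-node fallback decision hinges on the relative error computed by `calc_approx_error` (mean absolute error normalized by the mean FP32 magnitude); a standalone check of that metric, with made-up tensor values:

```python
import numpy as np

# Made-up FP32 reference vs. quantized output for one activation tensor
expected = np.array([1.0, -2.0, 4.0])
observed = np.array([1.1, -1.9, 3.8])

# mean(|observed - expected|) / mean(|expected|) = (0.4 / 3) / (7 / 3) = 0.4 / 7
error = np.abs(observed - expected).mean() / np.abs(expected).mean()
print(round(error, 4))  # 0.0571, just under the 0.06 fallback threshold used above
```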

scripts/question_answering/models.py

Lines changed: 14 additions & 4 deletions

```diff
@@ -180,6 +180,7 @@ def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
         self.answerable_scores.add(nn.Dense(2, flatten=False,
                                             weight_initializer=weight_initializer,
                                             bias_initializer=bias_initializer))
+        self.quantized_backbone = None
 
     def get_start_logits(self, contextual_embedding, p_mask):
         """
@@ -287,10 +288,14 @@ def forward(self, tokens, token_types, valid_length, p_mask, start_position):
             Shape (batch_size, sequence_length)
         answerable_logits
         """
+        backbone_net = self.backbone
+        if self.quantized_backbone != None:
+            backbone_net = self.quantized_backbone
+
         if self.use_segmentation:
-            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+            contextual_embeddings = backbone_net(tokens, token_types, valid_length)
         else:
-            contextual_embeddings = self.backbone(tokens, valid_length)
+            contextual_embeddings = backbone_net(tokens, valid_length)
         start_logits = self.get_start_logits(contextual_embeddings, p_mask)
         end_logits = self.get_end_logits(contextual_embeddings,
                                          np.expand_dims(start_position, axis=1),
@@ -337,11 +342,16 @@ def inference(self, tokens, token_types, valid_length, p_mask,
             The answerable logits. Here 0 --> answerable and 1 --> not answerable.
             Shape (batch_size, sequence_length, 2)
         """
+        backbone_net = self.backbone
+        if self.quantized_backbone != None:
+            backbone_net = self.quantized_backbone
+
         # Shape (batch_size, sequence_length, C)
         if self.use_segmentation:
-            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+            contextual_embeddings = backbone_net(tokens, token_types, valid_length)
         else:
-            contextual_embeddings = self.backbone(tokens, valid_length)
+            contextual_embeddings = backbone_net(tokens, valid_length)
+
         start_logits = self.get_start_logits(contextual_embeddings, p_mask)
         # The shape of start_top_index will be (..., start_top_n)
         start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
```
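For context, a hedged sketch of how the new `quantized_backbone` hook would be used once INC tuning finishes; `quantizer` and `qa_net` are hypothetical names, not part of this diff:

```python
# Hypothetical wiring, assuming the INC flow sketched earlier in this commit:
qmodel = quantizer.fit()                   # tuned INT8 backbone produced by INC
qa_net.quantized_backbone = qmodel.model   # forward()/inference() now dispatch to it
```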

0 commit comments