
Commit 3567c06

Author: zhangqi3 committed
[Update] Deploy to TVM && TQT fake quantize.
1 parent 371cc4d · commit 3567c06

File tree: 15 files changed, +1817 −188 lines changed

mqbench/adaround.py

Lines changed: 69 additions & 146 deletions
@@ -6,11 +6,16 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.fx import GraphModule, Node
+from torch.quantization.observer import ObserverBase
+

-from mqbench.observer import MinMaxObserver, ObserverBase
 from mqbench.utils import deepcopy_graphmodule
+from mqbench.utils.state import enable_quantization, disable_all
+from mqbench.utils.logger import logger
+

-_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear, )
+__all__ = ['adaround']
+_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear)


 def lp_norm(prediction, target, p=2.0):
@@ -26,6 +31,7 @@ def lp_norm(prediction, target, p=2.0):
     """
     return (prediction - target).abs().pow(p).sum(1).mean()

+
 def _rectified_sigmoid(x, zeta, gamma):
     """Function to generate rounding mask.

@@ -39,60 +45,28 @@ def _rectified_sigmoid(x, zeta, gamma):
     """
     return ((zeta - gamma) * torch.sigmoid(x) + gamma).clamp(0, 1)

-def get_cali_samples(train_data_loader, num_samples, no_label=True):
-    """Generate sub-dataset for calibration.
-
-    Args:
-        train_data_loader (torch.utils.data.DataLoader):
-        num_samples (int):
-        no_label (bool, optional): If the dataloader has no labels. Defaults to True.

-    Returns:
-        torch.Tensor: Concatenated data matrix.
-    """
-    cali_data_list = []
-    if no_label:
-        for batch_data in train_data_loader:
-            cali_data_list.append(batch_data["image"])
-            if len(cali_data_list) >= num_samples:
-                break
-    else:
-        for batch_data, _ in train_data_loader:
-            cali_data_list.append(batch_data)
-            if len(cali_data_list) >= num_samples:
-                break
-    return torch.cat(cali_data_list, dim=0)[:num_samples].cpu()
-
-def adaround(model: GraphModule, train_data, n_samples: int = 128,
-             lr: float = 4e-3, batch_size: int = 128, max_iter: int = 8000,
-             weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1,
-             quant_min: int = -128, quant_max: int = 127, per_channel: bool = False):
+def adaround(model: GraphModule, cali_data,
+             lr: float = 0.001, batch_size: int = 128, max_iter: int = 8000,
+             weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1):
     """Main function to run AdaRound on a given model.

     Args:
-        model (GraphModule):
-        train_data (torch.utils.data.DataLoader):
-        n_samples (int, optional): Defaults to 128.
-        lr (float, optional): Defaults to 4e-3.
+        model (GraphModule): Model to adaround.
+        cali_data (torch.tensor): Stacked tensor.
+        lr (float, optional): Defaults to 0.001.
         batch_size (int, optional): Defaults to 128.
         max_iter (int, optional): Defaults to 8000.
         weight (float, optional): Defaults to 0.01.
         beta (float, optional): Defaults to 20.
         gamma (float, optional): Defaults to -0.1.
         zeta (float, optional): Defaults to 1.1.
-        quant_min (int, optional): Defaults to -128.
-        quant_max (int, optional): Defaults to 127.
-        per_channel (bool, optional): Defaults to False.

     Returns:
         GraphModule: Modified copy of the given model.
     """
-    model.cpu()
-    print("AdaRound: Quant-Range="
-          "[{}, {}], Per-Channel={}".format(quant_min, quant_max, per_channel))
-
-    # sample data from training data
-    cali_data = get_cali_samples(train_data, n_samples)
+    device = cali_data.device
+    model.to(device)

     # apply rewritten deepcopy of GraphModule
     quant_model = deepcopy_graphmodule(model)
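
With this change, adaround() no longer samples its own calibration set from a DataLoader: the caller passes a pre-stacked tensor, and that tensor's device also decides where AdaRound runs. A minimal usage sketch of the new signature (hedged: `traced_model` and `train_loader` are placeholder names, and the stacking below simply mirrors the removed get_cali_samples helper):

    import torch

    # Collect roughly 128 calibration samples from a (data, label) DataLoader.
    cali_batches = []
    for images, _ in train_loader:
        cali_batches.append(images)
        if sum(b.shape[0] for b in cali_batches) >= 128:
            break
    cali_data = torch.cat(cali_batches, dim=0)[:128].cuda()  # tensor device = device AdaRound uses

    # traced_model is assumed to be an FX GraphModule already carrying MQBench fake-quantize modules.
    adarounded = adaround(traced_model, cali_data, batch_size=128, max_iter=8000)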
@@ -103,50 +77,33 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
     fp_observer_binding_dict = _insert_observer(model, "output")
     quant_observer_binding_dict = _insert_observer(quant_model, "input")

-    print("Record Outputs (by CPU) ...")
+    logger.info("Record Outputs ...")
     # apply data to record output
+    disable_all(model)
+    enable_quantization(quant_model)
+
     saver = FpOutputSaver(model, observer_binding_dict=fp_observer_binding_dict,
                           input_data=cali_data)

     # get layers for reconstruction
     modules = dict(quant_model.named_modules())
     quant_module_name_list = _get_quant_modules_by_topology(quant_model)

-    # TODO: more observer types / affine mode
-    if per_channel:
-        qscheme = torch.per_channel_symmetric
-        ch_axis = 0
-    else:
-        qscheme = torch.per_tensor_symmetric
-        ch_axis = -1
-
-    observer_type = MinMaxObserver.with_args(dtype=torch.qint8, quant_min=quant_min, quant_max=quant_max,
-                                             reduce_range=False, qscheme=qscheme, ch_axis=ch_axis)
-
-    scale_dict = _init_weight_scale(quant_model, quant_observer_binding_dict.keys(), observer_type)
-
     # disable gradient for all parameters
-    for n, m in quant_model.named_modules():
-        if hasattr(m, "weight"):
-            m.weight.requires_grad = False
-        if hasattr(m, "bias") and getattr(m, "bias") is not None:
-            m.bias.requires_grad = False
-
-    quant_model.cuda()
-    cali_data = cali_data.cuda()
+    for p in quant_model.parameters():
+        p.requires_grad = False

     # learn the rounding mask for each layer
     for node_name in quant_module_name_list:
-        print("===> Train for Layer: {}".format(node_name))
+        logger.info("Adaround for Layer: {}".format(node_name))
         # get input and output tensors
-        output_tensor = saver.get_result_by_name(node_name).cuda()
+        output_tensor = saver.get_result_by_name(node_name).to(device)
         input_observer = modules[quant_observer_binding_dict[node_name].name]
         cur_node = _get_node_by_name(quant_model, node_name)
         if cur_node is not None:
             module = modules[cur_node.target]
         else:
             raise RuntimeError("Node not found in graph.")
-        module.eval()

         with _Recorder(input_observer):
             with torch.no_grad():
@@ -158,12 +115,14 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
         ada_reg_loss = AdaRoundReg(zeta=zeta, gamma=gamma, weight=weight,
                                    temp_anneal=temp_anneal, h_func=_rectified_sigmoid)

-        scale, zero_point = scale_dict[node_name]
-        ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, ch_axis=ch_axis,
-                                          scale=scale, zero_point=zero_point,
-                                          quant_min=quant_min, quant_max=quant_max)
+        weight_fake_quant = module.weight_fake_quant
+        ch_axis = weight_fake_quant.activation_post_process.ch_axis
+        scale, zero_point = weight_fake_quant.activation_post_process.calculate_qparams()
+        quant_min, quant_max = weight_fake_quant.activation_post_process._calculate_qmin_qmax()
+        ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, scale=scale, zero_point=zero_point,
+                                          quant_min=quant_min, quant_max=quant_max, ch_axis=ch_axis)

-        ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).cuda()
+        ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).to(device)

         alpha = learning_alpha(input_tensor, output_tensor,
                                ada_layer, ada_reg_loss, lr,
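
Context for the hunk above: the quantization parameters are no longer recomputed by a fresh MinMaxObserver; they are read from the fake-quantize module already attached to each QAT layer. An illustrative sketch (assumption: `qat_conv` is a qat/intrinsic-qat module exposing `weight_fake_quant`, as produced by the MQBench prepare step):

    observer = qat_conv.weight_fake_quant.activation_post_process
    scale, zero_point = observer.calculate_qparams()        # tensors; per-channel when ch_axis != -1
    ch_axis = observer.ch_axis                               # -1 means per-tensor
    quant_min, quant_max = observer._calculate_qmin_qmax()   # private helper, as used in this commit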
@@ -173,9 +132,26 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
         module.weight.data = ada_quantizer(module.weight, alpha)
         module.weight.requires_grad = False

+    _del_tensor_observer(quant_model, quant_observer_binding_dict)
+
     return quant_model


+def _del_tensor_observer(gm: GraphModule, observer_binding_dict):
+    modules = dict(gm.named_modules())
+    nodes = list(gm.graph.nodes)
+    # Quant model tensor observer insert in 'input' mode.
+    for node in observer_binding_dict.values():
+        delattr(gm, node.name)
+        for _node in list(node.users.keys()):
+            _node.args = node.args
+    for node in observer_binding_dict.values():
+        gm.graph.erase_node(node)
+
+    gm.recompile()
+    gm.graph.lint()
+
+
 def _insert_observer(gm: GraphModule, insert_type="input"):
     """Insert observers to record the input and output of target layers.

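
The new _del_tensor_observer above drops the temporary input observers by rewriting each user's args and erasing the node. For reference, an equivalent minimal torch.fx pattern using the public replace_all_uses_with API (a sketch, not the MQBench code):

    from torch.fx import GraphModule, Node

    def remove_observer_node(gm: GraphModule, node: Node):
        # Point every user at the observer's own input, then drop the node.
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
        gm.recompile()   # regenerate forward() from the mutated graph
        gm.graph.lint()  # sanity-check the mutated graph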
@@ -260,7 +236,7 @@ class FpOutputSaver:
     @torch.no_grad()
     def __init__(self, fp_gm: GraphModule,
                  observer_binding_dict: Dict[str, Node],
-                 save_loc="disk", root="./calibration",
+                 save_loc="disk", root="./cali_data_cache",
                  input_data=None):
         """
         Currently, there are two options provided to save floating point model
@@ -283,8 +259,8 @@ def __init__(self, fp_gm: GraphModule,
         self._data = dict()

         if self.save_loc == "disk" and not os.path.exists(self.data_root):
-            raise NotADirectoryError("The given path is not a folder."
-                                     "Ensure you give the correct path.")
+            logger.info('Save data on disk, create directory {}'.format(self.data_root))
+            os.mkdir(self.data_root)
         saving_operation = self._disk_saving_operation \
             if self.save_loc == "disk" else self._gpu_saving_operation

@@ -352,30 +328,6 @@ def _get_quant_modules_by_topology(gm: GraphModule):
             module_name_list.append(node.name)
     return module_name_list

-def _init_weight_scale(gm: GraphModule, observed_module_list, observer_type: Callable):
-    """Simulate the fake quant modules to calculate scales and zero-points.
-
-    Args:
-        gm (GraphModule):
-        observed_module_list (list):
-        observer_type (Callable):
-
-    Returns:
-        dict:
-    """
-    scale_dict = dict()
-    modules = dict(gm.named_modules())
-
-    for name in observed_module_list:
-        node = _get_node_by_name(gm, name)
-        if node.op == "call_module":
-            observer = observer_type()
-            module = modules[node.target]
-            weight = module.weight
-            observer(weight)
-            scale, zero_point = observer.calculate_qparams()
-            scale_dict[name] = (scale.cuda().detach(), zero_point.cuda().detach())
-    return scale_dict

 def _get_node_by_name(gm: GraphModule, node_name: str):
     """
@@ -446,8 +398,8 @@ def __call__(self, t):


 class AdaRoundQuantizer:
-    def __init__(self, reg: AdaRoundReg, ch_axis: int,
-                 scale, zero_point, quant_min=-128, quant_max=127,
+    def __init__(self, reg: AdaRoundReg, scale, zero_point,
+                 quant_min=-128, quant_max=127, ch_axis=-1,
                  soft=True):
         self.quant_min = quant_min
         self.quant_max = quant_max
@@ -465,11 +417,6 @@ def __init__(self, reg: AdaRoundReg, ch_axis: int,
     def __call__(self, w, alpha):
         scale = self.scale
         zero_point = self.zero_point
-        if self.ch_axis != -1:
-            new_shape = [1] * len(w.shape)
-            new_shape[self.ch_axis] = w.shape[self.ch_axis]
-            scale = self.scale.reshape(new_shape)
-            zero_point = self.zero_point.reshape(new_shape)

         if self.soft_quantize:
             w = (w / scale).floor() + self.h_func(alpha, self.zeta, self.gamma)
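
For context, AdaRoundQuantizer.__call__ follows the usual AdaRound rounding scheme: floor(w / s), plus the rectified-sigmoid mask in soft mode (or its binarized value in hard mode), clamped in the integer domain and then dequantized. A self-contained sketch under those assumptions (the zero-point/clamp handling mirrors the round_to_nearset_quant helper removed further down in this file):

    import torch

    def soft_adaround(w, alpha, scale, zero_point, quant_min, quant_max, zeta=1.1, gamma=-0.1):
        h = ((zeta - gamma) * torch.sigmoid(alpha) + gamma).clamp(0, 1)   # rounding mask in [0, 1]
        q = (w / scale).floor() + h                                       # learned rounding offset
        q = (q + zero_point).clamp(quant_min, quant_max) - zero_point     # clamp in integer domain
        return q * scale                                                  # dequantize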
@@ -483,15 +430,6 @@ def __call__(self, w, alpha):
         w = w * scale
         return w

-    def __repr__(self):
-        scale = self.scale.item()
-        if self.ch_axis != -1:
-            scale = "per-channel scale of " + str(tuple(self.scale.shape))
-        repr_str = "AdaRoundQuantizer(quant_min={}, quant_max={}, scale={}, " \
-                   "gamma={}, zeta={}, soft_quantize={})".format(self.quant_min, self.quant_max, scale,
-                                                                 self.gamma, self.zeta, self.soft_quantize)
-        return repr_str
-

 class AdaRoundLayer(nn.Module):
     def __init__(self, module: nn.Module,
@@ -506,16 +444,17 @@ def __init__(self, module: nn.Module,
         if self.module.bias is not None:
             self.module.bias.requires_grad = False

-        scale = self.quantizer.scale
         if self.quantizer.ch_axis != -1:
             new_shape = [1] * len(self.module.weight.shape)
             new_shape[self.quantizer.ch_axis] = self.module.weight.shape[self.quantizer.ch_axis]
-            scale = self.quantizer.scale.reshape(new_shape)
+            self.quantizer.scale = self.quantizer.scale.reshape(new_shape)
+            self.quantizer.zero_point = self.quantizer.zero_point.reshape(new_shape)

+        # Init rest.
+        scale = self.quantizer.scale
         rest = self.module.weight / scale - (self.module.weight / scale).floor()
         rest = -torch.log((reg.zeta - reg.gamma) / (rest - reg.gamma) - 1)
-
-        self.alpha = torch.nn.Parameter(rest.cuda(), True)
+        self.alpha = torch.nn.Parameter(rest, True)

     def forward(self, x):
         weight = self.quantizer(self.module.weight, self.alpha)
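
The constructor change above reshapes the per-channel scale and zero-point once, so they broadcast against the weight in every later use. A quick illustration of that broadcast (assuming a conv weight of shape (out_ch, in_ch, kh, kw) and ch_axis == 0):

    import torch

    weight = torch.randn(8, 4, 3, 3)
    scale = torch.rand(8) + 1e-2        # one scale per output channel
    new_shape = [1] * weight.dim()
    new_shape[0] = weight.shape[0]      # -> [8, 1, 1, 1]
    scale = scale.reshape(new_shape)
    print((weight / scale).shape)       # torch.Size([8, 4, 3, 3])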
@@ -529,6 +468,10 @@ def forward(self, x):
         else:
             raise RuntimeError("Unsupported module type.")

+        if isinstance(self.module, (torch.nn.intrinsic.qat.ConvReLU2d,
+                                    torch.nn.intrinsic.qat.LinearReLU)):
+            x = F.relu(x)
+
         return x


@@ -541,7 +484,7 @@ def learning_alpha(in_tensor: torch.Tensor,
                    batch_size: int,
                    max_iter: int) -> torch.Tensor:

-    optimizer = torch.optim.Adam([ada_layer.alpha], lr=learning_rate)
+    optimizer = torch.optim.Adam([ada_layer.alpha])

     for epoch in range(max_iter):
         for idx in range(np.ceil(len(in_tensor) / batch_size).astype(int)):
@@ -560,33 +503,13 @@ def learning_alpha(in_tensor: torch.Tensor,
             loss.backward()
             optimizer.step()

-        if epoch % 200 == 0:
-            print("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: "
-                  "{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p,
-                                                                      loss_reg, ada_reg.beta))
+        if epoch % 100 == 0:
+            logger.info("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: "
+                        "{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p,
+                                                                            loss_reg, ada_reg.beta))
     res = ada_reg.round_mask(ada_layer.alpha)
-    print("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format(
+    logger.info("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format(
         loss,
         res[res + 1e-4 >= 1.0].numel(), res[res <= 1e-4].numel(), torch.numel(res),
         (res[res + 1e-4 >= 1.0].numel() + res[res <= 1e-4].numel()) / torch.numel(res)))
-    return ada_layer.alpha
-
-@torch.no_grad()
-def round_to_nearset_quant(m: nn.Module, scale, zero_point, quant_min, quant_max, ch_axis):
-    w = m.weight
-    if ch_axis != -1:
-        new_shape = [1] * len(w.shape)
-        new_shape[ch_axis] = w.shape[ch_axis]
-        scale = scale.reshape(new_shape)
-        zero_point = zero_point.reshape(new_shape)
-
-    w = (w / scale).round()
-    w += zero_point
-    w = w.clamp(quant_min, quant_max)
-    w -= zero_point
-    w = w * scale
-
-    return w
-
-if __name__ == "__main__":
-    pass
+    return ada_layer.alpha
