
Commit 1f6ad9c

lixiangyin666 authored and Jianfeng Wang committed
feat(quantization): update quantization codebase to 1.0 (#79)
* feat(quantization): update quantization codebase to 1.0
* refactor(quantization): refactor ci by using shell script and change quantization code for format check
1 parent 25014fa · commit 1f6ad9c

File tree

17 files changed: +458 -450 lines

.github/workflows/ci.yml

Lines changed: 1 addition & 20 deletions
@@ -35,23 +35,4 @@ jobs:
 
       # Runs a set of commands using the runners shell
       - name: Format check
-        run: |
-          export PYTHONPATH=$PWD:$PYTHONPATH
-
-          CHECK_VISION=official/vision/
-          CHECK_NLP=official/nlp/
-          pip install pylint==2.5.2
-          pylint $CHECK_VISION $CHECK_NLP --rcfile=.pylintrc || pylint_ret=$?
-          echo test, and deploy your project.
-          if [ "$pylint_ret" ]; then
-            exit $pylint_ret
-          fi
-          echo "All lint steps passed!"
-
-          pip3 install flake8==3.7.9
-          flake8 official
-          echo "All flake check passed!"
-
-          pip3 install isort==4.3.21
-          isort --check-only -rc official
-          echo "All isort check passed!"
+        run: ./run_format_check.sh

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 *log*/
 *.jpg
 *.png
-*.txt
 
 # compilation and distribution
 __pycache__

official/quantization/README.md

Lines changed: 14 additions & 13 deletions
@@ -5,9 +5,9 @@
 
 | Model | top1 acc (float32) | FPS* (float32) | top1 acc (int8) | FPS* (int8) |
 | --- | --- | --- | --- | --- |
-| ResNet18 | 69.824 | 10.5 | 69.754 | 16.3 |
-| ShufflenetV1 (1.5x) | 71.954 | 17.3 | 70.656 | 25.3 |
-| MobilenetV2 | 72.820 | 13.1 | 71.378 | 17.4 |
+| ResNet18 | 69.796 | 10.5 | 69.814 | 16.3 |
+| ShufflenetV1 (1.5x) | 71.948 | 17.3 | 70.806 | 25.3 |
+| MobilenetV2 | 72.808 | 13.1 | 71.228 | 17.4 |
 
 **: FPS is measured on Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz, single 224x224 image*
 
@@ -18,12 +18,12 @@
 
 #### (Optional) Download Pretrained Models
 ```
-wget https://data.megengine.org.cn/models/weights/mobilenet_v2_normal_72820.pkl
-wget https://data.megengine.org.cn/models/weights/mobilenet_v2_qat_71378.pkl
-wget https://data.megengine.org.cn/models/weights/resnet18_normal_69824.pkl
-wget https://data.megengine.org.cn/models/weights/resnet18_qat_69754.pkl
-wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_normal_71954.pkl
-wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_qat_70656.pkl
+wget https://data.megengine.org.cn/models/weights/mobilenet_v2_normal_72808.pkl
+wget https://data.megengine.org.cn/models/weights/mobilenet_v2_qat_71228.pkl
+wget https://data.megengine.org.cn/models/weights/resnet18_normal_69796.pkl
+wget https://data.megengine.org.cn/models/weights/resnet18_qat_69814.pkl
+wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_normal_71948.pkl
+wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_qat_70806.pkl
 ```
 
 ## Quantization Aware Training (QAT)
@@ -44,7 +44,7 @@ for _ in range(...):
 
 ```python
 import megengine.quantization as Q
-import megengine.jit as jit
+from megengine.jit import trace
 
 model = ...
 
@@ -53,10 +53,11 @@ Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
 # real quant
 Q.quantize(model)
 
-@jit.trace(symbolic=True):
+@trace(symbolic=True, capture_as_const=True)
 def inference_func(x):
     return model(x)
 
+inference_func(x)
 inference_func.dump(...)
 ```
 
@@ -76,7 +77,7 @@ python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resne
 
 ## Step 2. Calibration
 ```
-python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl --mode calibration
+python3 calibration.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl
 ```
 
 ## Step 3. Test QAT model on ImageNet Testset
@@ -85,7 +86,7 @@ python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resne
 python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode qat
 ```
 
-or testing in quantized mode, which uses only cpu for inference and takes longer time
+or testing in quantized mode
 
 ```
 python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized -n 1
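Editor's note, to tie the README changes above together: under MegEngine 1.0 the export path is quantize_qat → (finetune) → quantize → trace with `capture_as_const=True` → run once → `dump`. Below is a minimal runnable sketch of that flow; the toy `Net`, the random input, and the `quantized_model.mge` filename are illustrative stand-ins, not anything this commit ships.

```python
import numpy as np

import megengine as mge
import megengine.module as M
import megengine.quantization as Q
from megengine.jit import trace


# toy stand-in for a real network: QuantStub/DequantStub bound the int8 region
class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()
        self.conv = M.ConvBnRelu2d(3, 8, 3, stride=2, padding=1)
        self.dequant = M.DequantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


model = Net()
model = Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

# a real run would finetune here; at minimum one training-mode forward lets
# the EMA observers record activation ranges before conversion
model.train()
model(mge.tensor(np.random.randn(1, 3, 224, 224).astype("float32")))

model = Q.quantize(model)  # replace fake-quant modules with real int8 ops
model.eval()


@trace(symbolic=True, capture_as_const=True)
def inference_func(x):
    return model(x)


# trace() is lazy: run the function once so the graph is recorded and the
# weights are captured as constants, then serialize it
x = mge.tensor(np.random.randn(1, 3, 224, 224).astype("float32"))
inference_func(x)
inference_func.dump("quantized_model.mge")  # illustrative output path
```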

official/quantization/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

official/quantization/calibration.py

Lines changed: 56 additions & 84 deletions
@@ -6,112 +6,91 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-"""Finetune a pretrained fp32 with int8 quantization aware training(QAT)"""
+"""Finetune a pretrained fp32 with int8 post train quantization(calibration)"""
 import argparse
 import collections
-import multiprocessing as mp
 import numbers
 import os
-import bisect
 import time
 
+# pylint: disable=import-error
+import models
+
 import megengine as mge
 import megengine.data as data
 import megengine.data.transform as T
 import megengine.distributed as dist
 import megengine.functional as F
-import megengine.jit as jit
-import megengine.optimizer as optim
 import megengine.quantization as Q
-
-import config
-import models
+from megengine.quantization.quantize import enable_observer, quantize, quantize_qat
 
 logger = mge.get_logger(__name__)
-# from imagenet_nori_dataset import ImageNetNoriDataset
-from megengine.quantization.quantize import enable_observer, quantize, quantize_qat
+
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-a", "--arch", default="resnet18", type=str)
     parser.add_argument("-d", "--data", default=None, type=str)
     parser.add_argument("-s", "--save", default="/data/models", type=str)
-    parser.add_argument("-c", "--checkpoint", default=None, type=str,
-                        help="pretrained model to finetune")
-
-    parser.add_argument("-m", "--mode", default="qat", type=str,
-                        choices=["normal", "qat", "quantized", "calibration"],
-                        help="Quantization Mode\n"
-                        "normal: no quantization, using float32\n"
-                        "qat: quantization aware training, simulate int8\n"
-                        "calibration: calibration\n"
-                        "quantized: convert mode to int8 quantized, inference only")
+    parser.add_argument(
+        "-c",
+        "--checkpoint",
+        default=None,
+        type=str,
+        help="pretrained model to finetune",
+    )
 
     parser.add_argument("-n", "--ngpus", default=None, type=int)
     parser.add_argument("-w", "--workers", default=4, type=int)
     parser.add_argument("--report-freq", default=50, type=int)
     args = parser.parse_args()
 
-    world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus
-
-    if world_size > 1:
-        # start distributed training, dispatch sub-processes
-        mp.set_start_method("spawn")
-        processes = []
-        for rank in range(world_size):
-            p = mp.Process(target=worker, args=(rank, world_size, args))
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join()
-    else:
-        worker(0, 1, args)
+    world_size = (
+        dist.helper.get_device_count_by_fork("gpu")
+        if args.ngpus is None
+        else args.ngpus
+    )
+    world_size = 1 if world_size == 0 else world_size
+    if world_size != 1:
+        logger.warning(
+            "Calibration only supports single GPU now, %d provided", world_size
+        )
+    proc_func = dist.launcher(worker) if world_size > 1 else worker
+    proc_func(world_size, args)
 
 
 def get_parameters(model, cfg):
     if isinstance(cfg.WEIGHT_DECAY, numbers.Number):
-        return {"params": model.parameters(requires_grad=True),
-                "weight_decay": cfg.WEIGHT_DECAY}
+        return {
+            "params": model.parameters(requires_grad=True),
+            "weight_decay": cfg.WEIGHT_DECAY,
+        }
 
     groups = collections.defaultdict(list)  # weight_decay -> List[param]
    for pname, p in model.named_parameters(requires_grad=True):
        wd = cfg.WEIGHT_DECAY(pname, p)
        groups[wd].append(p)
    groups = [
-        {"params": params, "weight_decay": wd}
-        for wd, params in groups.items()
+        {"params": params, "weight_decay": wd} for wd, params in groups.items()
    ]  # List[{param, weight_decay}]
    return groups


-def worker(rank, world_size, args):
+def worker(world_size, args):
     # pylint: disable=too-many-statements
 
+    rank = dist.get_rank()
     if world_size > 1:
         # Initialize distributed process group
         logger.info("init distributed process group {} / {}".format(rank, world_size))
-        dist.init_process_group(
-            master_ip="localhost",
-            master_port=23456,
-            world_size=world_size,
-            rank=rank,
-            dev=rank,
-        )
 
-    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
+    save_dir = os.path.join(args.save, args.arch + "." + "calibration")
     if not os.path.exists(save_dir):
         os.makedirs(save_dir, exist_ok=True)
     mge.set_log_file(os.path.join(save_dir, "log.txt"))
 
     model = models.__dict__[args.arch]()
-    cfg = config.get_finetune_config(args.arch)
 
-    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
-    total_batch_size = cfg.BATCH_SIZE * world_size
-    steps_per_epoch = 1280000 // total_batch_size
-    total_steps = steps_per_epoch * cfg.EPOCHS
-
     # load calibration model
     assert args.checkpoint
     logger.info("Load pretrained weights from %s", args.checkpoint)
@@ -121,70 +100,64 @@ def worker(rank, world_size, args):
 
     # Build valid datasets
     valid_dataset = data.dataset.ImageNet(args.data, train=False)
-    # valid_dataset = ImageNetNoriDataset(args.data)
     valid_sampler = data.SequentialSampler(
         valid_dataset, batch_size=100, drop_last=False
     )
     valid_queue = data.DataLoader(
         valid_dataset,
         sampler=valid_sampler,
         transform=T.Compose(
-            [
-                T.Resize(256),
-                T.CenterCrop(224),
-                T.Normalize(mean=128),
-                T.ToMode("CHW"),
-            ]
+            [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
         ),
         num_workers=args.workers,
     )
 
     # calibration
     model.fc.disable_quantize()
     model = quantize_qat(model, qconfig=Q.calibration_qconfig)
-
+
     # calculate scale
-    @jit.trace(symbolic=True)
     def calculate_scale(image, label):
         model.eval()
         enable_observer(model)
         logits = model(image)
-        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
-        acc1, acc5 = F.accuracy(logits, label, (1, 5))
+        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
+        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
         if dist.is_distributed():  # all_reduce_mean
-            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
-            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
-            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
+            loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
+            acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
+            acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
         return loss, acc1, acc5
-
-    # model.fc.disable_quantize()
+
     infer(calculate_scale, valid_queue, args)
 
     # quantized
     model = quantize(model)
 
     # eval quantized model
-    @jit.trace(symbolic=True)
     def eval_func(image, label):
         model.eval()
         logits = model(image)
-        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
-        acc1, acc5 = F.accuracy(logits, label, (1, 5))
+        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
+        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
         if dist.is_distributed():  # all_reduce_mean
-            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
-            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
-            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
+            loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
+            acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
+            acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
         return loss, acc1, acc5
-
+
     _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
     logger.info("TEST %f, %f", valid_acc, valid_acc5)
 
     # save quantized model
     mge.save(
         {"step": -1, "state_dict": model.state_dict()},
-        os.path.join(save_dir, "checkpoint-calibration.pkl")
+        os.path.join(save_dir, "checkpoint-calibration.pkl"),
+    )
+    logger.info(
+        "save in {}".format(os.path.join(save_dir, "checkpoint-calibration.pkl"))
     )
-    logger.info("save in {}".format(os.path.join(save_dir, "checkpoint-calibration.pkl")))
+
 
 def infer(model, data_queue, args):
     objs = AverageMeter("Loss")
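Editor's note for readers skimming the hunk above: the post-training quantization recipe boils down to three steps: install observers with `Q.calibration_qconfig`, run a few eval batches with observers enabled, then convert with `quantize`. A condensed, hedged sketch follows; the toy module and random batches stand in for ResNet18 and the ImageNet valid loader (the real script additionally calls `model.fc.disable_quantize()` to keep the classifier in float):

```python
import numpy as np

import megengine as mge
import megengine.module as M
import megengine.quantization as Q
from megengine.quantization.quantize import enable_observer, quantize, quantize_qat


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.quant = M.QuantStub()
        self.conv = M.ConvBnRelu2d(3, 8, 3, stride=2, padding=1)
        self.dequant = M.DequantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


model = Net()

# step 1: calibration_qconfig installs observers only; nothing is trained
model = quantize_qat(model, qconfig=Q.calibration_qconfig)

# step 2: feed calibration batches in eval mode with observers enabled so
# activation statistics are collected
model.eval()
enable_observer(model)
for _ in range(4):
    model(mge.tensor(np.random.randn(2, 3, 224, 224).astype("float32")))

# step 3: freeze the collected scales into a real int8 model
model = quantize(model)
```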
@@ -195,8 +168,8 @@ def infer(model, data_queue, args):
     t = time.time()
     for step, (image, label) in enumerate(data_queue):
         n = image.shape[0]
-        image = image.astype("float32")  # convert np.uint8 to float32
-        label = label.astype("int32")
+        image = mge.tensor(image, dtype="float32")
+        label = mge.tensor(label, dtype="int32")
 
         loss, acc1, acc5 = model(image, label)
 
@@ -207,9 +180,8 @@ def infer(model, data_queue, args):
             t = time.time()
 
         if step % args.report_freq == 0 and dist.get_rank() == 0:
-            logger.info("Step %d, %s %s %s %s",
-                        step, objs, top1, top5, total_time)
-
+            logger.info("Step %d, %s %s %s %s", step, objs, top1, top5, total_time)
+
         # break
         if step == args.report_freq:
             break