
Commit 134ad00

zhouyizhuang-megvii authored and dc3671 committed
feat(classification): Update classification codebase to 1.0 (#71)
* fix(quant): modify elemwise add to fuse_add_relu (#64)
* update classification codebase to 1.0
* Fix style according to reviews
* update README and command line args
* fix single gpu logic
* use all reduce sum mode

Co-authored-by: Dash Chen <[email protected]>
1 parent f5d9f0c commit 134ad00

File tree: 9 files changed, +853 -683 lines changed

official/vision/classification/resnet/README.md

Lines changed: 16 additions & 16 deletions
@@ -8,12 +8,12 @@

 | Model | top1 acc | top5 acc |
 | --- | --- | --- |
-| ResNet18 | 70.312 | 89.430 |
-| ResNet34 | 73.960 | 91.630 |
-| ResNet50 | 76.254 | 93.056 |
-| ResNet101 | 77.944 | 93.844 |
-| ResNet152 | 78.582 | 94.130 |
-| ResNeXt50 32x4d | 77.592 | 93.644 |
+| ResNet18 | 70.312 | 89.430 |
+| ResNet34 | 73.960 | 91.630 |
+| ResNet50 | 76.254 | 93.056 |
+| ResNet101 | 77.944 | 93.844 |
+| ResNet152 | 78.582 | 94.130 |
+| ResNeXt50 32x4d | 77.592 | 93.644 |
 | ResNeXt101 32x8d| 79.520 | 94.586 |

 Models defined in this directory can be loaded directly through `megengine.hub`, for example:
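
For context, the `megengine.hub` loading mentioned in that last line usually looks like the sketch below (illustrative only; the `"megengine/models"` entry string and the `pretrained` argument follow MegEngine's hub convention and are not shown in this diff):

```python
import megengine.hub

# Load a pretrained classifier from the (assumed) "megengine/models" hub entry
# and switch it to inference mode.
model = megengine.hub.load("megengine/models", "resnet50", pretrained=True)
model.eval()
```
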
@@ -64,20 +64,20 @@ python3 train.py --dataset-dir=/path/to/imagenet
 `train.py` provides flexible command-line options, including:

 - `--data`, root directory of the ImageNet dataset, default `/data/datasets/imagenet`;
-- `--arch`, network architecture to train, default `resnet18`;
-- `--batch-size`, batch size per GPU during training, default 32;
-- `--ngpus`, number of nodes/GPUs used for training, default 1; when multiple GPUs are used, training automatically switches to distributed mode;
-- `--save`, directory where models and logs are stored, default `/data/models`;
-- `--learning-rate`, initial learning rate, default 0.0125; under distributed training, the actual learning rate equals the initial learning rate multiplied by the number of nodes/GPUs;
-- `--epochs`, number of epochs to train, default 100.
+- `--arch`, network architecture to train, default `resnet50`;
+- `--batch-size`, batch size per GPU during training, default 64;
+- `--ngpus`, number of GPUs per node used for training, default `None`, i.e. use all GPUs; when multiple GPUs are used, training automatically switches to distributed mode;
+- `--save`, directory where models and logs are stored, default `output`;
+- `--learning-rate`, initial learning rate, default 0.025; under distributed training, the actual learning rate equals the initial learning rate multiplied by the total number of GPUs;
+- `--epochs`, number of epochs to train, default 90.

 For example, the following command trains a `resnet50` model on 2 GPUs with a batch size of 64:

 ```bash
 python3 train.py --data /path/to/imagenet \
     --arch resnet50 \
-    --batch-size 32 \
-    --learning-rate 0.0125 \
+    --batch-size 64 \
+    --learning-rate 0.025 \
     --ngpus 2 \
     --save /path/to/save_dir
 ```
@@ -95,9 +95,9 @@ python3 test.py --data=/path/to/imagenet --arch resnet50 --model /path/to/model
 The command-line options of `test.py` are as follows:

 - `--data`, root directory of the ImageNet dataset, default `/data/datasets/imagenet`
-- `--arch`, network architecture to test, default `resnet18`
+- `--arch`, network architecture to test, default `resnet50`
 - `--model`, model to test; defaults to the official pretrained model;
-- `--ngpus`, number of GPUs used for testing, default 1
+- `--ngpus`, number of GPUs used for testing, default `None`

 For more details, run `python3 test.py --help`.

official/vision/classification/resnet/inference.py

Lines changed: 13 additions & 8 deletions
@@ -9,14 +9,17 @@
 import argparse
 import json

+import model as resnet_model
+
 import cv2
-import megengine as mge
+import numpy as np
+
+import megengine
 import megengine.data.transform as T
 import megengine.functional as F
 import megengine.jit as jit
-import numpy as np

-import model as M
+logging = megengine.logger.get_logger()


 def main():
@@ -26,9 +29,12 @@ def main():
     parser.add_argument("-i", "--image", default=None, type=str)
     args = parser.parse_args()

-    model = getattr(M, args.arch)(pretrained=(args.model is None))
-    if args.model:
-        state_dict = mge.load(args.model)
+    model = resnet_model.__dict__[args.arch](pretrained=(args.model is None))
+    if args.model is not None:
+        logging.info("load from checkpoint %s", args.model)
+        checkpoint = megengine.load(args.model)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
         model.load_state_dict(state_dict)

     if args.image is None:
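
The new loading path above expects a checkpoint that may wrap the weights under a `"state_dict"` key; if that key is missing, `state_dict` is never assigned. A minimal, more defensive sketch of the same idea (the `load_weights` helper is hypothetical and not part of this diff; it falls back to treating the file as a bare state dict):

```python
import megengine


def load_weights(model, path):
    # Hypothetical helper: accept either a bare state dict or a checkpoint
    # dict that wraps the weights under "state_dict".
    checkpoint = megengine.load(path)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    return model
```
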
@@ -48,7 +54,6 @@ def main():
         ]
     )

-    @jit.trace(symbolic=True)
     def infer_func(processed_img):
         model.eval()
         logits = model(processed_img)
@@ -58,7 +63,7 @@ def infer_func(processed_img):
     processed_img = transform.apply(image)[np.newaxis, :]
     probs = infer_func(processed_img)

-    top_probs, classes = F.top_k(probs, k=5, descending=True)
+    top_probs, classes = F.topk(probs, k=5, descending=True)

     with open("../../../assets/imagenet_class_info.json") as fp:
         imagenet_class_index = json.load(fp)
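
`F.topk` with `descending=True` returns the largest `k` values along with their indices, which is how the top-5 predictions are read out above. A small self-contained sketch of that call (synthetic logits instead of a real model, and no JSON class mapping):

```python
import numpy as np

import megengine
import megengine.functional as F

# Synthetic 1x1000 "logits" standing in for a classifier output.
logits = megengine.tensor(np.random.randn(1, 1000).astype("float32"))
probs = F.softmax(logits)

# Largest 5 probabilities per row, together with their class indices.
top_probs, classes = F.topk(probs, k=5, descending=True)
for prob, cls in zip(top_probs.numpy()[0], classes.numpy()[0]):
    print(f"class {cls}: {prob:.4f}")
```
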

official/vision/classification/resnet/test.py

Lines changed: 138 additions & 88 deletions
@@ -7,85 +7,170 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import argparse
-import multiprocessing as mp
+import bisect
+import multiprocessing
+import os
+import threading
 import time

-import megengine as mge
-import megengine.data as data
-import megengine.data.transform as T
-import megengine.distributed as dist
-import megengine.functional as F
-import megengine.jit as jit
+import model as resnet_model

-import model as M
+import megengine
+from megengine import data as data
+from megengine import distributed as dist
+from megengine import functional as F
+from megengine import jit as jit
+from megengine.data import transform as T

-logger = mge.get_logger(__name__)
+logging = megengine.logger.get_logger()


 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-a", "--arch", default="resnet18", type=str)
-    parser.add_argument("-d", "--data", default=None, type=str)
-    parser.add_argument("-m", "--model", default=None, type=str)
-
-    parser.add_argument("-n", "--ngpus", default=None, type=int)
-    parser.add_argument("-w", "--workers", default=4, type=int)
-    parser.add_argument("--report-freq", default=50, type=int)
-    args = parser.parse_args()
+    parser = argparse.ArgumentParser(description="MegEngine ImageNet Training")
+    parser.add_argument("-d", "--data", metavar="DIR", help="path to imagenet dataset")
+    parser.add_argument(
+        "-a",
+        "--arch",
+        default="resnet50",
+        help="model architecture (default: resnet50)",
+    )
+    parser.add_argument(
+        "-n",
+        "--ngpus",
+        default=None,
+        type=int,
+        help="number of GPUs per node (default: None, use all available GPUs)",
+    )
+    parser.add_argument(
+        "-m", "--model", metavar="PKL", default=None, help="path to model checkpoint"
+    )

-    world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus
+    parser.add_argument("-j", "--workers", default=2, type=int)
+    parser.add_argument(
+        "-p",
+        "--print-freq",
+        default=20,
+        type=int,
+        metavar="N",
+        help="print frequency (default: 10)",
+    )

-    if world_size > 1:
-        # start distributed training, dispatch sub-processes
-        mp.set_start_method("spawn")
-        processes = []
-        for rank in range(world_size):
-            p = mp.Process(target=worker, args=(rank, world_size, args))
-            p.start()
-            processes.append(p)
+    parser.add_argument("--dist-addr", default="localhost")
+    parser.add_argument("--dist-port", default=23456)
+    parser.add_argument("--world-size", default=1)
+    parser.add_argument("--rank", default=0)

-        for p in processes:
-            p.join()
-    else:
-        worker(0, 1, args)
+    args = parser.parse_args()
+
+    # create server if is master
+    if args.rank <= 0:
+        dist.Server(port=args.dist_port)
+
+    # get device count
+    with multiprocessing.Pool(1) as pool:
+        ngpus_per_node, _ = pool.map(megengine.get_device_count, ["gpu", "cpu"])
+    if args.ngpus:
+        ngpus_per_node = args.ngpus
+
+    # launch processes
+    procs = []
+    for local_rank in range(ngpus_per_node):
+        p = multiprocessing.Process(
+            target=worker,
+            kwargs=dict(
+                rank=args.rank * ngpus_per_node + local_rank,
+                world_size=args.world_size * ngpus_per_node,
+                ngpus_per_node=ngpus_per_node,
+                args=args,
+            ),
+        )
+        p.start()
+        procs.append(p)

+    # join processes
+    for p in procs:
+        p.join()

-def worker(rank, world_size, args):
+
+def worker(rank, world_size, ngpus_per_node, args):
     if world_size > 1:
-        # Initialize distributed process group
-        logger.info("init distributed process group {} / {}".format(rank, world_size))
+        # init process group
         dist.init_process_group(
-            master_ip="localhost",
-            master_port=23456,
+            master_ip=args.dist_addr,
+            port=args.dist_port,
             world_size=world_size,
             rank=rank,
-            dev=rank,
+            device=rank % ngpus_per_node,
+            backend="nccl",
+        )
+        logging.info(
+            "init process group rank %d / %d", dist.get_rank(), dist.get_world_size()
         )

-    model = getattr(M, args.arch)(pretrained=(args.model is None))
+    # build dataset
+    _, valid_dataloader = build_dataset(args)

-    if args.model:
-        logger.info("load weights from %s", args.model)
-        model.load_state_dict(mge.load(args.model))
+    # build model
+    model = resnet_model.__dict__[args.arch](pretrained=args.model is None)
+    if args.model is not None:
+        logging.info("load from checkpoint %s", args.model)
+        checkpoint = megengine.load(args.model)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+        model.load_state_dict(state_dict)

-    @jit.trace(symbolic=True)
-    def valid_func(image, label):
-        model.eval()
+    def valid_step(image, label):
         logits = model(image)
-        loss = F.cross_entropy_with_softmax(logits, label)
-        acc1, acc5 = F.accuracy(logits, label, (1, 5))
-        if dist.is_distributed():  # all_reduce_mean
-            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
-            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
-            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
+        loss = F.nn.cross_entropy(logits, label)
+        acc1, acc5 = F.topk_accuracy(logits, label, topk=(1, 5))
+        # calculate mean values
+        if world_size > 1:
+            loss = F.distributed.all_reduce_sum(loss) / world_size
+            acc1 = F.distributed.all_reduce_sum(acc1) / world_size
+            acc5 = F.distributed.all_reduce_sum(acc5) / world_size
         return loss, acc1, acc5

-    logger.info("preparing dataset..")
+    model.eval()
+    _, valid_acc1, valid_acc5 = valid(valid_step, valid_dataloader, args)
+    logging.info(
+        "Test Acc@1 %.3f, Acc@5 %.3f", valid_acc1, valid_acc5,
+    )
+
+
+def valid(func, data_queue, args):
+    objs = AverageMeter("Loss")
+    top1 = AverageMeter("Acc@1")
+    top5 = AverageMeter("Acc@5")
+    clck = AverageMeter("Time")
+
+    t = time.time()
+    for step, (image, label) in enumerate(data_queue):
+        image = megengine.tensor(image, dtype="float32")
+        label = megengine.tensor(label, dtype="int32")
+
+        n = image.shape[0]
+
+        loss, acc1, acc5 = func(image, label)
+
+        objs.update(loss.item(), n)
+        top1.update(100 * acc1.item(), n)
+        top5.update(100 * acc5.item(), n)
+        clck.update(time.time() - t, n)
+        t = time.time()
+
+        if step % args.print_freq == 0 and dist.get_rank() == 0:
+            logging.info("Test step %d, %s %s %s %s", step, objs, top1, top5, clck)
+
+    return objs.avg, top1.avg, top5.avg
+
+
+def build_dataset(args):
+    train_dataloader = None
     valid_dataset = data.dataset.ImageNet(args.data, train=False)
     valid_sampler = data.SequentialSampler(
         valid_dataset, batch_size=100, drop_last=False
     )
-    valid_queue = data.DataLoader(
+    valid_dataloader = data.DataLoader(
         valid_dataset,
         sampler=valid_sampler,
         transform=T.Compose(
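
This hunk is where the commit message's "use all reduce sum mode" shows up: each worker computes its local loss and accuracy, then the values are summed across processes with `F.distributed.all_reduce_sum` and divided by `world_size` to obtain a global mean. The pattern in isolation, as a sketch (the `all_reduce_mean` helper name is mine, not the repository's):

```python
import megengine.functional as F


def all_reduce_mean(value, world_size):
    # Average a metric across workers: sum it over all processes, then divide.
    # With a single process the reduction is skipped, mirroring the
    # `if world_size > 1` guard in valid_step above.
    if world_size > 1:
        value = F.distributed.all_reduce_sum(value) / world_size
    return value
```
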
@@ -100,42 +185,7 @@ def valid_func(image, label):
         ),
         num_workers=args.workers,
     )
-    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
-    logger.info("Valid %.3f / %.3f", valid_acc, valid_acc5)
-
-
-def infer(model, data_queue, args, epoch=0):
-    objs = AverageMeter("Loss")
-    top1 = AverageMeter("Acc@1")
-    top5 = AverageMeter("Acc@5")
-    total_time = AverageMeter("Time")
-
-    t = time.time()
-    for step, (image, label) in enumerate(data_queue):
-        n = image.shape[0]
-        image = image.astype("float32")  # convert np.uint8 to float32
-        label = label.astype("int32")
-
-        loss, acc1, acc5 = model(image, label)
-
-        objs.update(loss.numpy()[0], n)
-        top1.update(100 * acc1.numpy()[0], n)
-        top5.update(100 * acc5.numpy()[0], n)
-        total_time.update(time.time() - t)
-        t = time.time()
-
-        if step % args.report_freq == 0 and dist.get_rank() == 0:
-            logger.info(
-                "Epoch %d Step %d, %s %s %s %s",
-                epoch,
-                step,
-                objs,
-                top1,
-                top5,
-                total_time,
-            )
-
-    return objs.avg, top1.avg, top5.avg
+    return train_dataloader, valid_dataloader


 class AverageMeter:
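
The diff stops at the `class AverageMeter:` header without showing its body. Below is a typical meter consistent with how it is used above (`AverageMeter("Loss")`, `update(value, n)`, `.avg`, and logging via `%s`); the repository's actual implementation may differ:

```python
class AverageMeter:
    """Track the current value and running average of a metric (sketch)."""

    def __init__(self, name, fmt=":.3f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.cnt = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # Accumulate a value observed over a batch of size n.
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

    def __str__(self):
        # Render as "Name current (average)" so it can be logged with %s.
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)
```
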
