77# software distributed under the License is distributed on an
88# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99import argparse
10- import multiprocessing as mp
10+ import bisect
11+ import multiprocessing
12+ import os
13+ import threading
1114import time
1215
13- import megengine as mge
14- import megengine .data as data
15- import megengine .data .transform as T
16- import megengine .distributed as dist
17- import megengine .functional as F
18- import megengine .jit as jit
16+ import model as resnet_model
1917
20- import model as M
18+ import megengine
19+ from megengine import data as data
20+ from megengine import distributed as dist
21+ from megengine import functional as F
22+ from megengine import jit as jit
23+ from megengine .data import transform as T
2124
22- logger = mge . get_logger (__name__ )
25+ logging = megengine . logger . get_logger ()
2326
2427
def main():
    """Parse CLI arguments, start the distributed server on the master node,
    spawn one evaluation worker process per GPU, and wait for them to finish.
    """
    parser = argparse.ArgumentParser(description="MegEngine ImageNet Training")
    parser.add_argument("-d", "--data", metavar="DIR", help="path to imagenet dataset")
    parser.add_argument(
        "-a",
        "--arch",
        default="resnet50",
        help="model architecture (default: resnet50)",
    )
    parser.add_argument(
        "-n",
        "--ngpus",
        default=None,
        type=int,
        help="number of GPUs per node (default: None, use all available GPUs)",
    )
    parser.add_argument(
        "-m", "--model", metavar="PKL", default=None, help="path to model checkpoint"
    )

    parser.add_argument("-j", "--workers", default=2, type=int)
    parser.add_argument(
        "-p",
        "--print-freq",
        default=20,
        type=int,
        metavar="N",
        help="print frequency (default: 20)",
    )

    parser.add_argument("--dist-addr", default="localhost")
    # type=int is required on the three options below: they take part in
    # integer arithmetic (rank/world-size math) and are passed to
    # dist.init_process_group; without it a command-line value arrives as
    # str, and e.g. "1" * 8 silently repeats the string instead of multiplying.
    parser.add_argument("--dist-port", default=23456, type=int)
    parser.add_argument("--world-size", default=1, type=int)
    parser.add_argument("--rank", default=0, type=int)

    args = parser.parse_args()

    # create server if is master
    if args.rank <= 0:
        dist.Server(port=args.dist_port)

    # query the device count in a short-lived subprocess so this parent
    # process does not initialize the GPU runtime before spawning workers
    with multiprocessing.Pool(1) as pool:
        ngpus_per_node, _ = pool.map(megengine.get_device_count, ["gpu", "cpu"])
    if args.ngpus:
        ngpus_per_node = args.ngpus

    # launch one worker process per local GPU
    procs = []
    for local_rank in range(ngpus_per_node):
        p = multiprocessing.Process(
            target=worker,
            kwargs=dict(
                rank=args.rank * ngpus_per_node + local_rank,
                world_size=args.world_size * ngpus_per_node,
                ngpus_per_node=ngpus_per_node,
                args=args,
            ),
        )
        p.start()
        procs.append(p)

    # join processes
    for p in procs:
        p.join()
5293
def worker(rank, world_size, ngpus_per_node, args):
    """Run evaluation in one process.

    Args:
        rank: global rank of this process across all nodes.
        world_size: total number of processes across all nodes.
        ngpus_per_node: processes (GPUs) per node; used to pick the local device.
        args: parsed command-line namespace produced by ``main``.
    """
    if world_size > 1:
        # init process group
        dist.init_process_group(
            master_ip=args.dist_addr,
            port=args.dist_port,
            world_size=world_size,
            rank=rank,
            device=rank % ngpus_per_node,
            backend="nccl",
        )
        logging.info(
            "init process group rank %d / %d", dist.get_rank(), dist.get_world_size()
        )

    # build dataset (training loader is unused in evaluation)
    _, valid_dataloader = build_dataset(args)

    # build model; download pretrained weights only when no checkpoint is given
    model = resnet_model.__dict__[args.arch](pretrained=args.model is None)
    if args.model is not None:
        logging.info("load from checkpoint %s", args.model)
        checkpoint = megengine.load(args.model)
        # A checkpoint may be either a dict wrapping the weights under
        # "state_dict" or the raw state dict itself. The original code left
        # `state_dict` unbound in the latter case, raising NameError.
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        model.load_state_dict(checkpoint)

    def valid_step(image, label):
        # Forward pass plus loss/top-k accuracy; in distributed mode the
        # metrics are all-reduced and averaged across processes.
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        acc1, acc5 = F.topk_accuracy(logits, label, topk=(1, 5))
        # calculate mean values
        if world_size > 1:
            loss = F.distributed.all_reduce_sum(loss) / world_size
            acc1 = F.distributed.all_reduce_sum(acc1) / world_size
            acc5 = F.distributed.all_reduce_sum(acc5) / world_size
        return loss, acc1, acc5

    model.eval()
    _, valid_acc1, valid_acc5 = valid(valid_step, valid_dataloader, args)
    logging.info(
        "Test Acc@1 %.3f, Acc@5 %.3f", valid_acc1, valid_acc5,
    )
138+
139+
def valid(func, data_queue, args):
    """Run one full evaluation pass.

    Args:
        func: callable taking (image, label) tensors and returning
            (loss, acc1, acc5) scalars.
        data_queue: iterable yielding (image, label) numpy batches.
        args: namespace providing ``print_freq``.

    Returns:
        Tuple of running averages: (loss, top-1 accuracy, top-5 accuracy).
    """
    loss_meter = AverageMeter("Loss")
    acc1_meter = AverageMeter("Acc@1")
    acc5_meter = AverageMeter("Acc@5")
    time_meter = AverageMeter("Time")

    last = time.time()
    for step, (image, label) in enumerate(data_queue):
        image = megengine.tensor(image, dtype="float32")
        label = megengine.tensor(label, dtype="int32")

        batch = image.shape[0]

        loss, acc1, acc5 = func(image, label)

        # accuracies are reported as percentages; meters are batch-weighted
        loss_meter.update(loss.item(), batch)
        acc1_meter.update(100 * acc1.item(), batch)
        acc5_meter.update(100 * acc5.item(), batch)
        now = time.time()
        time_meter.update(now - last, batch)
        last = now

        # only rank 0 logs, every print_freq steps
        if step % args.print_freq == 0 and dist.get_rank() == 0:
            logging.info(
                "Test step %d, %s %s %s %s",
                step,
                loss_meter,
                acc1_meter,
                acc5_meter,
                time_meter,
            )

    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg
165+
166+
167+ def build_dataset (args ):
168+ train_dataloader = None
84169 valid_dataset = data .dataset .ImageNet (args .data , train = False )
85170 valid_sampler = data .SequentialSampler (
86171 valid_dataset , batch_size = 100 , drop_last = False
87172 )
88- valid_queue = data .DataLoader (
173+ valid_dataloader = data .DataLoader (
89174 valid_dataset ,
90175 sampler = valid_sampler ,
91176 transform = T .Compose (
@@ -100,42 +185,7 @@ def valid_func(image, label):
100185 ),
101186 num_workers = args .workers ,
102187 )
103- _ , valid_acc , valid_acc5 = infer (valid_func , valid_queue , args )
104- logger .info ("Valid %.3f / %.3f" , valid_acc , valid_acc5 )
105-
106-
107- def infer (model , data_queue , args , epoch = 0 ):
108- objs = AverageMeter ("Loss" )
109- top1 = AverageMeter ("Acc@1" )
110- top5 = AverageMeter ("Acc@5" )
111- total_time = AverageMeter ("Time" )
112-
113- t = time .time ()
114- for step , (image , label ) in enumerate (data_queue ):
115- n = image .shape [0 ]
116- image = image .astype ("float32" ) # convert np.uint8 to float32
117- label = label .astype ("int32" )
118-
119- loss , acc1 , acc5 = model (image , label )
120-
121- objs .update (loss .numpy ()[0 ], n )
122- top1 .update (100 * acc1 .numpy ()[0 ], n )
123- top5 .update (100 * acc5 .numpy ()[0 ], n )
124- total_time .update (time .time () - t )
125- t = time .time ()
126-
127- if step % args .report_freq == 0 and dist .get_rank () == 0 :
128- logger .info (
129- "Epoch %d Step %d, %s %s %s %s" ,
130- epoch ,
131- step ,
132- objs ,
133- top1 ,
134- top5 ,
135- total_time ,
136- )
137-
138- return objs .avg , top1 .avg , top5 .avg
188+ return train_dataloader , valid_dataloader
139189
140190
141191class AverageMeter :
0 commit comments