@@ -146,3 +146,75 @@ def run(global_rank,
146146 flush = True )
147147
148148 dev .PrintTimeProfiling ()
149+
150+ if __name__ == '__main__' :
151+ parser = argparse .ArgumentParser (
152+ description = 'Training using the autograd and graph.' )
153+ parser .add_argument (
154+ 'model' ,
155+ choices = ['cnn' , 'resnet' , 'xceptionnet' , 'mlp' , 'alexnet' , 'candidiasisnet' ],
156+ default = 'candidiasisnet' )
157+ parser .add_argument ('data' ,
158+ choices = ['mnist' , 'cifar10' , 'cifar100' , 'candidiasis' ],
159+ default = 'candidiasis' )
160+ parser .add_argument ('-p' ,
161+ choices = ['float32' , 'float16' ],
162+ default = 'float32' ,
163+ dest = 'precision' )
164+ parser .add_argument ('-m' ,
165+ '--max-epoch' ,
166+ default = 100 ,
167+ type = int ,
168+ help = 'maximum epochs' ,
169+ dest = 'max_epoch' )
170+ parser .add_argument ('-b' ,
171+ '--batch-size' ,
172+ default = 64 ,
173+ type = int ,
174+ help = 'batch size' ,
175+ dest = 'batch_size' )
176+ parser .add_argument ('-l' ,
177+ '--learning-rate' ,
178+ default = 0.005 ,
179+ type = float ,
180+ help = 'initial learning rate' ,
181+ dest = 'lr' )
182+ parser .add_argument ('-i' ,
183+ '--device-id' ,
184+ default = 0 ,
185+ type = int ,
186+ help = 'which GPU to use' ,
187+ dest = 'device_id' )
188+ parser .add_argument ('-g' ,
189+ '--disable-graph' ,
190+ default = 'True' ,
191+ action = 'store_false' ,
192+ help = 'disable graph' ,
193+ dest = 'graph' )
194+ parser .add_argument ('-v' ,
195+ '--log-verbosity' ,
196+ default = 0 ,
197+ type = int ,
198+ help = 'logging verbosity' ,
199+ dest = 'verbosity' )
200+ parser .add_argument ('-dir' ,
201+ '--dir-path' ,
202+ type = str ,
203+ help = 'the directory to store the candidiasis dataset' ,
204+ dest = 'dir_path' )
205+
206+ args = parser .parse_args ()
207+
208+ sgd = opt .SGD (lr = args .lr , momentum = 0.9 , weight_decay = 1e-5 , dtype = singa_dtype [args .precision ])
209+ run (0 ,
210+ 1 ,
211+ args .device_id ,
212+ args .max_epoch ,
213+ args .batch_size ,
214+ args .model ,
215+ args .data ,
216+ sgd ,
217+ args .graph ,
218+ args .verbosity ,
219+ precision = args .precision ,
220+ dir_path = args .dir_path )
0 commit comments