@@ -161,3 +161,144 @@ def data_partition(dataset_x, dataset_y, global_rank, world_size):
     idx_start = global_rank * data_per_rank
     idx_end = (global_rank + 1) * data_per_rank
     return dataset_x[idx_start:idx_end], dataset_y[idx_start:idx_end]
+
+
+def train_mnist_cnn(DIST=False,
+                    local_rank=None,
+                    world_size=None,
+                    nccl_id=None,
+                    spars=0,
+                    topK=False,
+                    corr=True):
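+    # DIST enables distributed training; nccl_id, local_rank and world_size
+    # configure the NCCL communicator. spars, topK and corr control DistOpt's
+    # sparse update: the sparsification value, top-K vs. threshold selection,
+    # and (presumably) local accumulation of the sparsification error.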
+
+    # Define the hyperparameters for the MNIST CNN
+    max_epoch = 10
+    batch_size = 64
+    sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
+
+    # Prepare the training and validation data
+    train_x, train_y, test_x, test_y = load_dataset()
+    IMG_SIZE = 28
+    num_classes = 10
+    train_y = to_categorical(train_y, num_classes)
+    test_y = to_categorical(test_y, num_classes)
+
+    # Normalize pixel values to [0, 1]
+    train_x = train_x / 255
+    test_x = test_x / 255
+
+    if DIST:
+        # For distributed GPU training
+        sgd = opt.DistOpt(sgd,
+                          nccl_id=nccl_id,
+                          local_rank=local_rank,
+                          world_size=world_size)
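+        # DistOpt wraps the plain SGD so each backward pass all-reduces
+        # gradients across ranks; bind this process to the GPU that
+        # matches its local rank.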
+        dev = device.create_cuda_gpu_on(sgd.local_rank)
+
+        # Dataset partition for distributed training
+        train_x, train_y = data_partition(train_x, train_y, sgd.global_rank,
+                                          sgd.world_size)
+        test_x, test_y = data_partition(test_x, test_y, sgd.global_rank,
+                                        sgd.world_size)
+        world_size = sgd.world_size
+    else:
+        # For single GPU
+        dev = device.create_cuda_gpu()
+        world_size = 1
+
+    # Create model
+    model = CNN()
+
+    tx = tensor.Tensor((batch_size, 1, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
+    ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+
+    if DIST:
+        # Synchronize the initial parameters
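+        # Run one dummy batch forward and backward so every parameter is
+        # materialized, then synchronize() presumably all-reduces each
+        # parameter so all ranks start from identical weights.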
+        autograd.training = True
+        x = np.random.randn(batch_size, 1, IMG_SIZE,
+                            IMG_SIZE).astype(np.float32)
+        y = np.zeros(shape=(batch_size, num_classes), dtype=np.int32)
+        tx.copy_from_numpy(x)
+        ty.copy_from_numpy(y)
+        out = model.forward(tx)
+        loss = autograd.softmax_cross_entropy(out, ty)
+        for p, g in autograd.backward(loss):
+            synchronize(p, sgd)
+
+    # Training and evaluation loop
+    for epoch in range(max_epoch):
+        start_time = time.time()
+        np.random.shuffle(idx)
+
+        if (not DIST) or (sgd.global_rank == 0):
+            print('Starting Epoch %d:' % (epoch))
+
+        # Training phase
+        autograd.training = True
+        train_correct = np.zeros(shape=[1], dtype=np.float32)
+        test_correct = np.zeros(shape=[1], dtype=np.float32)
+        train_loss = np.zeros(shape=[1], dtype=np.float32)
+
+        for b in range(num_train_batch):
+            x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
+            x = augmentation(x, batch_size)
+            y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            out = model.forward(tx)
+            loss = autograd.softmax_cross_entropy(out, ty)
+            train_correct += accuracy(tensor.to_numpy(out), y)
+            train_loss += tensor.to_numpy(loss)[0]
+            if DIST:
+                if spars == 0:
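+                    # Dense update: tensors with fewer elements than
+                    # `threshold` are presumably fused into one buffer
+                    # before all-reduce, cutting the number of NCCL calls.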
+                    sgd.backward_and_update(loss, threshold=50000)
+                else:
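+                    # Sparse update: communicate only significant gradient
+                    # entries (top-K if topK is True, otherwise by absolute
+                    # threshold), with corr presumably accumulating the
+                    # dropped residual locally.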
+                    sgd.backward_and_sparse_update(loss,
+                                                   spars=spars,
+                                                   topK=topK,
+                                                   corr=corr)
+            else:
+                sgd(loss)
+
+        if DIST:
+            # Reduce the evaluation accuracy and loss from multiple devices
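+            # reduce_variable presumably copies each host-side counter into
+            # the GPU reducer tensor, all-reduces it across ranks, and
+            # returns the summed result.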
+            reducer = tensor.Tensor((1,), dev, tensor.float32)
+            train_correct = reduce_variable(train_correct, sgd, reducer)
+            train_loss = reduce_variable(train_loss, sgd, reducer)
+
+        # Output the training loss and accuracy
+        if (not DIST) or (sgd.global_rank == 0):
+            print('Training loss = %f, training accuracy = %f' %
+                  (train_loss, train_correct /
+                   (num_train_batch * batch_size * world_size)),
+                  flush=True)
+
+        # Evaluation phase
+        autograd.training = False
+        for b in range(num_test_batch):
+            x = test_x[b * batch_size:(b + 1) * batch_size]
+            y = test_y[b * batch_size:(b + 1) * batch_size]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            out_test = model.forward(tx)
+            test_correct += accuracy(tensor.to_numpy(out_test), y)
+
+        if DIST:
+            # Reduce the evaluation accuracy from multiple devices
+            test_correct = reduce_variable(test_correct, sgd, reducer)
+
+        # Output the evaluation accuracy
+        if (not DIST) or (sgd.global_rank == 0):
+            print('Evaluation accuracy = %f, Elapsed Time = %fs' %
+                  (test_correct / (num_test_batch * batch_size * world_size),
+                   time.time() - start_time),
+                  flush=True)
+
+
+if __name__ == '__main__':
+
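+    # Runs the single-GPU configuration by default. For distributed
+    # training, launch one process per GPU (e.g. via MPI) and call
+    # train_mnist_cnn with DIST=True, a shared nccl_id, and each
+    # process's local_rank and world_size.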
+    DIST = False
+    train_mnist_cnn(DIST=DIST)