@@ -58,25 +58,6 @@ class Entry(object):
     The class to encapsulate all operations.
     """

-    def _check(self):
-        """
-        Check the validation of parameters.
-        """
-        supported_types = [
-            "softmax",
-            "arcface",
-            "dist_softmax",
-            "dist_arcface",
-        ]
-        assert self.loss_type in supported_types, \
-            "All supported types are {}, but given {}.".format(
-                supported_types, self.loss_type)
-
-        if self.loss_type in ["dist_softmax", "dist_arcface"]:
-            assert self.num_trainers > 1, \
-                "At least 2 trainers are required for distributed fc-layer. " \
-                "You can start your job using paddle.distributed.launch module."
-
     def __init__(self):
         self.config = config.config
         super(Entry, self).__init__()
@@ -118,8 +99,12 @@ def __init__(self):
         self.val_targets = self.config.val_targets
         self.dataset_dir = self.config.dataset_dir
         self.num_classes = self.config.num_classes
+        self.sample_ratio = self.config.sample_ratio
+        self.model_parallel = self.config.model_parallel
         self.loss_type = self.config.loss_type
-        self.margin = self.config.margin
+        self.margin1 = self.config.margin1
+        self.margin2 = self.config.margin2
+        self.margin3 = self.config.margin3
         self.scale = self.config.scale
         self.lr = self.config.lr
         self.lr_steps = self.config.lr_steps
@@ -128,12 +113,16 @@ def __init__(self):
         self.model_name = self.config.model_name
         self.emb_dim = self.config.emb_dim
         self.train_epochs = self.config.train_epochs
+        self.train_steps = self.config.train_steps
         self.checkpoint_dir = self.config.checkpoint_dir
         self.with_test = self.config.with_test
         self.model_save_dir = self.config.model_save_dir
         self.warmup_epochs = self.config.warmup_epochs
         self.calc_train_acc = False

+        assert not (self.train_epochs and self.train_steps), \
+            'Only one of train_epochs and train_steps can be set.'
+
         self.max_last_checkpoint_num = 5
         if self.checkpoint_dir:
             self.checkpoint_dir = os.path.abspath(self.checkpoint_dir)
@@ -166,6 +155,8 @@ def __init__(self):
             logger.info('\t ' + str(key) + ": " + str(self.config[key]))
         logger.info('trainer_id: {}, num_trainers: {}'.format(trainer_id,
                                                               num_trainers))
+        logger.info('global_train_batch_size: {}'.format(
+            self.global_train_batch_size))
         logger.info('default lr_decay_factor: {}'.format(self.lr_decay_factor))
         logger.info('default log period: {}'.format(self.log_period))
         logger.info('default test period: {}'.format(self.test_period))
@@ -327,6 +318,23 @@ def set_class_num(self, num):
         self.num_classes = num
         logger.info("Set num_classes to {}.".format(num))

+    def set_model_parallel(self, flag):
+        """
+        Set the flag for model parallel training.
+        """
+        self.model_parallel = flag
+        if flag:
+            assert self.num_trainers > 1, "The number of GPUs must be " \
+                "greater than 1 when using model parallel training."
+        logger.info("Set model_parallel to {}.".format(flag))
+
+    def set_sample_ratio(self, sample_ratio):
+        """
+        Set the sample ratio of Partial FC.
+        """
+        self.sample_ratio = sample_ratio
+        logger.info("Set sample_ratio to {}.".format(sample_ratio))
+
     def set_emb_size(self, size):
         """
         Set the size of the last hidden layer before the distributed fc-layer.
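
A minimal usage sketch of the two switches added above (the 'from plsc import Entry' entry point is assumed from the project's README and is not part of this diff):

    from plsc import Entry

    ins = Entry()
    ins.set_model_parallel(True)  # model-parallel FC; requires num_trainers > 1
    ins.set_sample_ratio(0.1)     # Partial FC: sample 10% of class centers per step
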
@@ -348,9 +356,18 @@ def set_train_epochs(self, num):
348356 """
349357 Set the number of epochs to train.
350358 """
359+ self .train_steps = None
351360 self .train_epochs = num
352361 logger .info ("Set train_epochs to {}." .format (num ))
353362
363+ def set_train_steps (self , num ):
364+ """
365+ Set the number of steps to train.
366+ """
367+ self .train_epochs = None
368+ self .train_steps = num
369+ logger .info ("Set train_steps to {}." .format (num ))
370+
354371 def set_checkpoint_dir (self , directory ):
355372 """
356373 Set the directory for checkpoint loaded before training/testing.
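
For illustration, the two setters above are mutually exclusive by construction: each clears the other, so only the most recent call takes effect (continuing the assumed Entry instance from the earlier sketch):

    ins.set_train_epochs(20)    # clears train_steps; train for 20 epochs
    ins.set_train_steps(40000)  # clears train_epochs; train for 40000 steps,
                                # epochs are derived later in _get_optimizer
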
@@ -371,15 +388,39 @@ def set_warmup_epochs(self, num):
         self.warmup_epochs = num
         logger.info("Set warmup_epochs to {}.".format(num))

-    def set_loss_type(self, loss_type):
-        supported_types = [
-            "dist_softmax", "dist_arcface", "softmax", "arcface"
-        ]
-        if loss_type not in supported_types:
-            raise ValueError("All supported loss types: {}".format(
-                supported_types))
+    def set_loss_type(self,
+                      loss_type,
+                      margin1=None,
+                      margin2=None,
+                      margin3=None):
396+ """
397+ Set the loss type. Supported arcface, cosface, sphereface loss type.
398+ You also can set combined margin loss by yourself via marign1, margin2, maring3.
399+ """
         self.loss_type = loss_type
-        logger.info("Set loss_type to {}.".format(loss_type))
+        if "arcface" == loss_type:
+            self.margin1 = 1.0 if margin1 is None else margin1
+            self.margin2 = 0.5 if margin2 is None else margin2
+            self.margin3 = 0.0 if margin3 is None else margin3
+        elif "cosface" == loss_type:
+            self.margin1 = 1.0 if margin1 is None else margin1
+            self.margin2 = 0.0 if margin2 is None else margin2
+            self.margin3 = 0.35 if margin3 is None else margin3
+        elif "sphereface" == loss_type:
+            self.margin1 = 1.35 if margin1 is None else margin1
+            self.margin2 = 0.0 if margin2 is None else margin2
+            self.margin3 = 0.0 if margin3 is None else margin3
+        else:
+            self.margin1 = margin1
+            self.margin2 = margin2
+            self.margin3 = margin3
+            assert self.margin1 is not None, "margin1 must be set"
+            assert self.margin2 is not None, "margin2 must be set"
+            assert self.margin3 is not None, "margin3 must be set"
+
+        logger.info(
+            "Set loss_type to {}, margin1 = {}, margin2 = {}, margin3 = {}.".
+            format(loss_type, self.margin1, self.margin2, self.margin3))

     def set_optimizer(self, optimizer):
         if not isinstance(optimizer, Optimizer):
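
These defaults follow the InsightFace-style combined margin, roughly cos(margin1 * theta + margin2) - margin3, so arcface, cosface and sphereface are just preset triples. A hedged sketch of calling the new setter (same assumed Entry instance; the custom type name below is hypothetical):

    ins.set_loss_type('arcface')     # margin1=1.0,  margin2=0.5, margin3=0.0
    ins.set_loss_type('cosface')     # margin1=1.0,  margin2=0.0, margin3=0.35
    ins.set_loss_type('sphereface')  # margin1=1.35, margin2=0.0, margin3=0.0
    ins.set_loss_type('combined',    # any other name: all three margins required
                      margin1=1.0, margin2=0.3, margin3=0.2)
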
@@ -421,6 +462,8 @@ def _get_optimizer(self):
         steps_per_pass = int(
             math.ceil(images_per_trainer * 1.0 / self.train_batch_size))
         logger.info("Steps per epoch: %d" % steps_per_pass)
+        if self.train_epochs is None:
+            self.train_epochs = self.train_steps // steps_per_pass + 1
         warmup_steps = steps_per_pass * self.warmup_epochs
         batch_denom = 1024
         base_lr = start_lr * global_batch_size / batch_denom
@@ -445,12 +488,11 @@ def _get_optimizer(self):
                 weight_decay=paddle.regularizer.L2Decay(5e-4))
         self.optimizer = optimizer

-        if self.loss_type in ["dist_softmax", "dist_arcface"]:
+        if self.model_parallel:
             self.optimizer = DistributedClassificationOptimizer(
                 self.optimizer,
                 self.train_batch_size,
                 use_fp16=self.use_fp16,
-                loss_type=self.loss_type,
                 fp16_user_dict=self.fp16_user_dict)
         elif self.use_fp16:
             self.optimizer = paddle.static.amp.decorate(
@@ -486,23 +528,32 @@ def build_program(self, is_train=True, use_parallel_test=False):
                 input_field.build()
                 self.input_field = input_field

+                if self.model_parallel:
+                    msg = 'Using model parallelism for training.'
+                    logger.info(msg)
+                if self.sample_ratio < 1.0:
+                    msg = 'Using Partial FC and sample ratio = %.2f.' % self.sample_ratio
+                    logger.info(msg)
                 emb, loss, prob = model.get_output(
                     input=input_field,
                     num_classes=self.num_classes,
                     num_ranks=num_trainers,
                     rank_id=trainer_id,
+                    model_parallel=self.model_parallel,
                     is_train=is_train,
                     param_attr=self.param_attr,
                     bias_attr=self.bias_attr,
-                    loss_type=self.loss_type,
-                    margin=self.margin,
-                    scale=self.scale)
+                    margin1=self.margin1,
+                    margin2=self.margin2,
+                    margin3=self.margin3,
+                    scale=self.scale,
+                    sample_ratio=self.sample_ratio)

                 acc1 = None
                 acc5 = None

-                if self.loss_type in ["dist_softmax", "dist_arcface"]:
-                    if self.calc_train_acc:
+                if self.calc_train_acc:
+                    if self.model_parallel:
                         shard_prob = loss._get_info("shard_prob")

                         prob_list = []
@@ -520,8 +571,7 @@ def build_program(self, is_train=True, use_parallel_test=False):
                             input=prob,
                             label=paddle.reshape(label_all, (-1, 1)),
                             k=5)
-                else:
-                    if self.calc_train_acc:
+                    else:
                         acc1 = paddle.static.accuracy(
                             input=prob,
                             label=paddle.reshape(input_field.label, (-1, 1)),
@@ -540,7 +590,7 @@ def build_program(self, is_train=True, use_parallel_test=False):
                         dist_optimizer.minimize(loss)
                     else:  # single card training
                         optimizer.minimize(loss)
-                    if "dist" in self.loss_type or self.use_fp16:
+                    if self.model_parallel or self.use_fp16:
                         optimizer = optimizer._optimizer
                 elif use_parallel_test:
                     emb_list = []
@@ -714,9 +764,7 @@ def load(self, program, for_train=True):
             else:
                 state_dict[name] = tensor

-        distributed = self.loss_type in ["dist_softmax", "dist_arcface"]
-
-        if for_train or distributed:
+        if for_train or self.model_parallel:
             meta_file = os.path.join(checkpoint_dir, 'meta.json')
             if not os.path.exists(meta_file):
                 logger.error(
@@ -729,7 +777,7 @@ def load(self, program, for_train=True):
                 config = json.load(handle)

         # Preprocess distributed parameters.
-        if distributed:
+        if self.model_parallel:
             pretrain_nranks = config['pretrain_nranks']
             assert pretrain_nranks > 0
             emb_dim = config['emb_dim']
@@ -899,8 +947,6 @@ def _run_test(self, exe, test_list, test_name_list, feeder, fetch_list):
             sys.stdout.flush()

     def test(self):
-        self._check()
-
         trainer_id = self.trainer_id
         num_trainers = self.num_trainers

@@ -979,7 +1025,6 @@ def test(self):
         logger.info("test time: {:.4f}".format(test_end - test_start))

     def train(self):
-        self._check()
         self.has_run_train = True

         trainer_id = self.trainer_id