
Commit bb786d8

* support Partial FC (#85)
* support cosface loss function
* support resnet100
* remove softmax and dist-softmax loss pass
* using model_parallel flag to replace dist-* loss type statement
* using margin loss to unify arcface, cosface, sphereface, softmax loss function
1 parent d0b8452 commit bb786d8
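
The unification mentioned in the last bullet is the usual combined-margin formulation: for a sample whose target-class cosine is cos(theta), the target logit becomes s * (cos(m1*theta + m2) - m3), so arcface, cosface, sphereface and plain softmax differ only in their (m1, m2, m3) presets. A minimal NumPy sketch of that computation, using the presets added in this commit (an illustration of the idea, not code from the repository):

import numpy as np

# (m1, m2, m3) presets matching set_loss_type() in this commit;
# plain softmax corresponds to (1.0, 0.0, 0.0).
PRESETS = {
    "arcface":    (1.0, 0.5, 0.0),
    "cosface":    (1.0, 0.0, 0.35),
    "sphereface": (1.35, 0.0, 0.0),
    "softmax":    (1.0, 0.0, 0.0),
}

def combined_margin_logit(cos_theta, loss_type="arcface", scale=64.0):
    """Apply s * (cos(m1*theta + m2) - m3) to the target-class cosine."""
    m1, m2, m3 = PRESETS[loss_type]
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    return scale * (np.cos(m1 * theta + m2) - m3)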

7 files changed: +260 -359 lines (only plsc/config.py and plsc/entry.py are shown below)


plsc/config.py

Lines changed: 9 additions & 6 deletions
@@ -13,8 +13,6 @@
 # limitations under the License.
 
 from easydict import EasyDict as edict
-
-
 """
 Default Parameters
 """
@@ -27,17 +25,22 @@
 config.dataset_dir = './train_data'
 config.train_image_num = 5822653
 config.model_name = 'ResNet50'
-config.train_epochs = 120
+config.train_epochs = None
+config.train_steps = 180000
 config.checkpoint_dir = ""
 config.with_test = True
 config.model_save_dir = "output"
 config.warmup_epochs = 0
+config.model_parallel = False
 
-config.loss_type = "dist_arcface"
+config.loss_type = "arcface"
 config.num_classes = 85742
+config.sample_ratio = 1.0
 config.image_shape = (3, 112, 112)
-config.margin = 0.5
+config.margin1 = 1.0
+config.margin2 = 0.5
+config.margin3 = 0.0
 config.scale = 64.0
 config.lr = 0.1
-config.lr_steps = (100000, 160000, 220000)
+config.lr_steps = (100000, 160000, 180000)
 config.emb_dim = 512
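
The new config.sample_ratio drives Partial FC: each step computes the fc-layer and softmax over only a sampled subset of class centers, namely every class present in the batch plus randomly chosen negatives, up to roughly sample_ratio * num_classes centers in total. A rough sketch of that sampling step (a simplified single-card illustration of the technique, not the PLSC implementation):

import numpy as np

def partial_fc_sample(labels, num_classes, sample_ratio, rng=np.random):
    """Select the class centers used for one training step of Partial FC."""
    num_sample = max(1, int(num_classes * sample_ratio))
    positive = np.unique(labels)                 # classes that appear in this batch
    if len(positive) >= num_sample:
        return positive
    negatives = np.setdiff1d(np.arange(num_classes), positive)
    extra = rng.choice(negatives, num_sample - len(positive), replace=False)
    return np.concatenate([positive, extra])

# With the defaults above (num_classes = 85742) and sample_ratio = 0.1,
# each step would train against roughly 8574 sampled centers.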

plsc/entry.py

Lines changed: 90 additions & 45 deletions
@@ -58,25 +58,6 @@ class Entry(object):
     The class to encapsulate all operations.
     """
 
-    def _check(self):
-        """
-        Check the validation of parameters.
-        """
-        supported_types = [
-            "softmax",
-            "arcface",
-            "dist_softmax",
-            "dist_arcface",
-        ]
-        assert self.loss_type in supported_types, \
-            "All supported types are {}, but given {}.".format(
-                supported_types, self.loss_type)
-
-        if self.loss_type in ["dist_softmax", "dist_arcface"]:
-            assert self.num_trainers > 1, \
-                "At least 2 trainers are required for distributed fc-layer. " \
-                "You can start your job using paddle.distributed.launch module."
-
     def __init__(self):
         self.config = config.config
         super(Entry, self).__init__()
@@ -118,8 +99,12 @@ def __init__(self):
         self.val_targets = self.config.val_targets
         self.dataset_dir = self.config.dataset_dir
         self.num_classes = self.config.num_classes
+        self.sample_ratio = self.config.sample_ratio
+        self.model_parallel = self.config.model_parallel
         self.loss_type = self.config.loss_type
-        self.margin = self.config.margin
+        self.margin1 = self.config.margin1
+        self.margin2 = self.config.margin2
+        self.margin3 = self.config.margin3
         self.scale = self.config.scale
         self.lr = self.config.lr
         self.lr_steps = self.config.lr_steps
@@ -128,12 +113,16 @@ def __init__(self):
         self.model_name = self.config.model_name
         self.emb_dim = self.config.emb_dim
         self.train_epochs = self.config.train_epochs
+        self.train_steps = self.config.train_steps
         self.checkpoint_dir = self.config.checkpoint_dir
         self.with_test = self.config.with_test
         self.model_save_dir = self.config.model_save_dir
         self.warmup_epochs = self.config.warmup_epochs
         self.calc_train_acc = False
 
+        assert not (self.train_epochs and self.train_steps
+                    ), 'train_steps and train_epochs only one can be set'
+
         self.max_last_checkpoint_num = 5
         if self.checkpoint_dir:
             self.checkpoint_dir = os.path.abspath(self.checkpoint_dir)
@@ -166,6 +155,8 @@ def __init__(self):
             logger.info('\t' + str(key) + ": " + str(self.config[key]))
         logger.info('trainer_id: {}, num_trainers: {}'.format(trainer_id,
                                                               num_trainers))
+        logger.info('global_train_batch_size: {}'.format(
+            self.global_train_batch_size))
         logger.info('default lr_decay_factor: {}'.format(self.lr_decay_factor))
         logger.info('default log period: {}'.format(self.log_period))
         logger.info('default test period: {}'.format(self.test_period))
@@ -327,6 +318,23 @@ def set_class_num(self, num):
         self.num_classes = num
         logger.info("Set num_classes to {}.".format(num))
 
+    def set_model_parallel(self, flag):
+        """
+        Set the flag of model parallel.
+        """
+        self.model_parallel = flag
+        if flag:
+            assert self.num_trainers > 1, "The number of GPUs must greater " \
+                "than 1 when using model parallel training"
+        logger.info("Set model_parallel to {}.".format(flag))
+
+    def set_sample_ratio(self, sample_ratio):
+        """
+        Set the sample ratio of Partial FC.
+        """
+        self.sample_ratio = sample_ratio
+        logger.info("Set sample_ratio to {}.".format(sample_ratio))
+
     def set_emb_size(self, size):
         """
         Set the size of the last hidding layer before the distributed fc-layer.
@@ -348,9 +356,18 @@ def set_train_epochs(self, num):
         """
         Set the number of epochs to train.
         """
+        self.train_steps = None
         self.train_epochs = num
         logger.info("Set train_epochs to {}.".format(num))
 
+    def set_train_steps(self, num):
+        """
+        Set the number of steps to train.
+        """
+        self.train_epochs = None
+        self.train_steps = num
+        logger.info("Set train_steps to {}.".format(num))
+
     def set_checkpoint_dir(self, directory):
         """
         Set the directory for checkpoint loaded before training/testing.
@@ -371,15 +388,39 @@ def set_warmup_epochs(self, num):
         self.warmup_epochs = num
         logger.info("Set warmup_epochs to {}.".format(num))
 
-    def set_loss_type(self, loss_type):
-        supported_types = [
-            "dist_softmax", "dist_arcface", "softmax", "arcface"
-        ]
-        if loss_type not in supported_types:
-            raise ValueError("All supported loss types: {}".format(
-                supported_types))
+    def set_loss_type(self,
+                      loss_type,
+                      margin1=None,
+                      margin2=None,
+                      margin3=None):
+        """
+        Set the loss type. Supported arcface, cosface, sphereface loss type.
+        You also can set combined margin loss by yourself via marign1, margin2, maring3.
+        """
         self.loss_type = loss_type
-        logger.info("Set loss_type to {}.".format(loss_type))
+        if "arcface" == loss_type:
+            self.margin1 = 1.0 if margin1 is None else margin1
+            self.margin2 = 0.5 if margin2 is None else margin2
+            self.margin3 = 0.0 if margin3 is None else margin3
+        elif "cosface" == loss_type:
+            self.margin1 = 1.0 if margin1 is None else margin1
+            self.margin2 = 0.0 if margin2 is None else margin2
+            self.margin3 = 0.35 if margin3 is None else margin3
+        elif "sphereface" == loss_type:
+            self.margin1 = 1.35 if margin1 is None else margin1
+            self.margin2 = 0.0 if margin2 is None else margin2
+            self.margin3 = 0.0 if margin3 is None else margin3
+        else:
+            self.margin1 = margin1
+            self.margin2 = margin2
+            self.margin3 = margin3
+            assert self.margin1 is not None, "margin1 must be set"
+            assert self.margin2 is not None, "margin2 must be set"
+            assert self.margin3 is not None, "margin3 must be set"
+
+        logger.info(
+            "Set loss_type to {}, margin1 = {}, margin2 = {}, margin3 = {}.".
+            format(loss_type, self.margin1, self.margin2, self.margin3))
 
     def set_optimizer(self, optimizer):
         if not isinstance(optimizer, Optimizer):
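
Putting the new setters together, a training script might look like the sketch below. This is a hypothetical example that assumes the usual from plsc import Entry entry point and uses placeholder values; with model parallelism enabled the job still has to be launched across multiple GPUs with paddle.distributed.launch:

from plsc import Entry

ins = Entry()
ins.set_model_parallel(True)     # requires num_trainers > 1
ins.set_sample_ratio(0.1)        # Partial FC: keep ~10% of class centers per step
ins.set_loss_type('cosface')     # presets margin1=1.0, margin2=0.0, margin3=0.35
ins.set_train_steps(180000)      # clears train_epochs; only one of the two may be set
ins.train()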
@@ -421,6 +462,8 @@ def _get_optimizer(self):
             steps_per_pass = int(
                 math.ceil(images_per_trainer * 1.0 / self.train_batch_size))
             logger.info("Steps per epoch: %d" % steps_per_pass)
+            if self.train_epochs is None:
+                self.train_epochs = self.train_steps // steps_per_pass + 1
             warmup_steps = steps_per_pass * self.warmup_epochs
             batch_denom = 1024
             base_lr = start_lr * global_batch_size / batch_denom
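
For example, if steps_per_pass came out to 11,000, the default train_steps = 180000 would be converted to train_epochs = 180000 // 11000 + 1 = 17; the exact figure depends on the dataset size, number of trainers and batch size.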
@@ -445,12 +488,11 @@ def _get_optimizer(self):
                 weight_decay=paddle.regularizer.L2Decay(5e-4))
             self.optimizer = optimizer
 
-        if self.loss_type in ["dist_softmax", "dist_arcface"]:
+        if self.model_parallel:
             self.optimizer = DistributedClassificationOptimizer(
                 self.optimizer,
                 self.train_batch_size,
                 use_fp16=self.use_fp16,
-                loss_type=self.loss_type,
                 fp16_user_dict=self.fp16_user_dict)
         elif self.use_fp16:
             self.optimizer = paddle.static.amp.decorate(
@@ -486,23 +528,32 @@ def build_program(self, is_train=True, use_parallel_test=False):
                 input_field.build()
                 self.input_field = input_field
 
+                if self.model_parallel:
+                    msg = 'Using model parallelism for training.'
+                    logger.info(msg)
+                if self.sample_ratio < 1.0:
+                    msg = 'Using Partial FC and sample ratio = %.2f.' % self.sample_ratio
+                    logger.info(msg)
                 emb, loss, prob = model.get_output(
                     input=input_field,
                     num_classes=self.num_classes,
                     num_ranks=num_trainers,
                     rank_id=trainer_id,
+                    model_parallel=self.model_parallel,
                     is_train=is_train,
                     param_attr=self.param_attr,
                     bias_attr=self.bias_attr,
-                    loss_type=self.loss_type,
-                    margin=self.margin,
-                    scale=self.scale)
+                    margin1=self.margin1,
+                    margin2=self.margin2,
+                    margin3=self.margin3,
+                    scale=self.scale,
+                    sample_ratio=self.sample_ratio)
 
                 acc1 = None
                 acc5 = None
 
-                if self.loss_type in ["dist_softmax", "dist_arcface"]:
-                    if self.calc_train_acc:
+                if self.calc_train_acc:
+                    if self.model_parallel:
                         shard_prob = loss._get_info("shard_prob")
 
                         prob_list = []
@@ -520,8 +571,7 @@ def build_program(self, is_train=True, use_parallel_test=False):
                             input=prob,
                             label=paddle.reshape(label_all, (-1, 1)),
                             k=5)
-                else:
-                    if self.calc_train_acc:
+                    else:
                         acc1 = paddle.static.accuracy(
                             input=prob,
                             label=paddle.reshape(input_field.label, (-1, 1)),
@@ -540,7 +590,7 @@ def build_program(self, is_train=True, use_parallel_test=False):
                         dist_optimizer.minimize(loss)
                     else: # single card training
                         optimizer.minimize(loss)
-                    if "dist" in self.loss_type or self.use_fp16:
+                    if self.model_parallel or self.use_fp16:
                         optimizer = optimizer._optimizer
                 elif use_parallel_test:
                     emb_list = []
@@ -714,9 +764,7 @@ def load(self, program, for_train=True):
             else:
                 state_dict[name] = tensor
 
-        distributed = self.loss_type in ["dist_softmax", "dist_arcface"]
-
-        if for_train or distributed:
+        if for_train or self.model_parallel:
             meta_file = os.path.join(checkpoint_dir, 'meta.json')
             if not os.path.exists(meta_file):
                 logger.error(
@@ -729,7 +777,7 @@ def load(self, program, for_train=True):
                 config = json.load(handle)
 
             # Preporcess distributed parameters.
-            if distributed:
+            if self.model_parallel:
                 pretrain_nranks = config['pretrain_nranks']
                 assert pretrain_nranks > 0
                 emb_dim = config['emb_dim']
@@ -899,8 +947,6 @@ def _run_test(self, exe, test_list, test_name_list, feeder, fetch_list):
             sys.stdout.flush()
 
     def test(self):
-        self._check()
-
         trainer_id = self.trainer_id
         num_trainers = self.num_trainers
 
@@ -979,7 +1025,6 @@ def test(self):
         logger.info("test time: {:.4f}".format(test_end - test_start))
 
     def train(self):
-        self._check()
         self.has_run_train = True
 
         trainer_id = self.trainer_id
