Skip to content

Commit 9e026a9

Browse files
committed
remove chief
1 parent 7fbddaa commit 9e026a9

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

python/paddle/fluid/io.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ def get_parameter_value_by_name(name, executor, program=None):
466466
def save_checkpoint(executor,
467467
checkpoint_dir,
468468
trainer_id,
469-
is_chief=False,
470469
trainer_args=None,
471470
main_program=None,
472471
max_num_checkpoints=3):
@@ -478,8 +477,7 @@ def save_checkpoint(executor,
478477
479478
:param executor executor for save the value
480479
:param checkpoint_dir the checkpoint directory
481-
:param trainer_id currect trainer id
482-
:param is_chief if the trainer id equals 0, the is_chief will be true
480+
:param trainer_id currect trainer id, if id is equal to 0, the trainer is chief
483481
:param main_program will save all variables in program
484482
:param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
485483
"""
@@ -497,7 +495,7 @@ def save_checkpoint(executor,
497495

498496
save_trainer_args(cur_dir, trainer_id, trainer_args)
499497

500-
if is_chief:
498+
if trainer_id == 0:
501499
save_persist_vars_without_grad(executor, cur_dir, main_program)
502500

503501
_scroll_delete(checkpoint_dir, max_num_checkpoints)

python/paddle/fluid/trainer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ def __init__(self,
136136
# config for checkpoint
137137
# only chief worker will save variables
138138
self.trainer_id = 0
139-
self.chief = True
140139
self.checkpoint_cfg = checkpoint_config
141140
if self.checkpoint_cfg:
142141
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
@@ -201,7 +200,6 @@ def _transpile_nccl2_dist(self):
201200
self.nccl_id_var = None
202201
else:
203202
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
204-
self.chief = self.trainer_id == 0
205203
port = os.getenv("PADDLE_PSERVER_PORT")
206204
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
207205
worker_endpoints = []
@@ -250,7 +248,7 @@ def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
250248
# the unique trainer id, starting from 0, needed by trainer
251249
# only
252250
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
253-
self.chief = self.trainer_id == 0
251+
254252
# the role, should be either PSERVER or TRAINER
255253
training_role = os.getenv("PADDLE_TRAINING_ROLE")
256254
with self._prog_and_scope_guard():
@@ -456,7 +454,6 @@ def _save_checkpoint(self, epoch_id, step_id):
456454
executor=exe,
457455
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
458456
trainer_id=self.trainer_id,
459-
is_chief=self.chief,
460457
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
461458
main_program=self.train_program,
462459
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

0 commit comments

Comments (0)