Skip to content

Commit 50cd5a7

Browse files
committed
fix save model
1 parent 9ebd2c6 commit 50cd5a7

File tree

2 files changed

26 additions and 7 deletions (+26 / -7 lines changed)

2 files changed

26 additions and 7 deletions (+26 / -7 lines changed)

tools/static_trainer.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -143,11 +143,20 @@ def main(args):
143143
else:
144144
logger.info("reader type wrong")
145145

146-
save_static_model(
147-
paddle.static.default_main_program(),
148-
model_save_path,
149-
epoch_id,
150-
prefix='rec_static')
146+
if use_fleet:
147+
trainer_id = paddle.distributed.get_rank()
148+
if trainer_id == 0:
149+
save_static_model(
150+
paddle.static.default_main_program(),
151+
model_save_path,
152+
epoch_id,
153+
prefix='rec_static')
154+
else:
155+
save_static_model(
156+
paddle.static.default_main_program(),
157+
model_save_path,
158+
epoch_id,
159+
prefix='rec_static')
151160

152161
if use_inference:
153162
feed_var_names = config.get("runner.save_inference_feed_varnames",

tools/trainer.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -195,8 +195,18 @@ def main(args):
195195
tensor_print_str + " epoch time: {:.2f} s".format(
196196
time.time() - epoch_begin))
197197

198-
save_model(
199-
dy_model, optimizer, model_save_path, epoch_id, prefix='rec')
198+
if use_fleet:
199+
trainer_id = paddle.distributed.get_rank()
200+
if trainer_id == 0:
201+
save_model(
202+
dy_model,
203+
optimizer,
204+
model_save_path,
205+
epoch_id,
206+
prefix='rec')
207+
else:
208+
save_model(
209+
dy_model, optimizer, model_save_path, epoch_id, prefix='rec')
200210

201211

202212
if __name__ == '__main__':

0 commit comments

Comments
 (0)