Skip to content

Commit 777be5a

Browse files
authored
Merge pull request #220 from vslyu/fix_save_step
fix save step bug
2 parents 1671a37 + 8e5db1c commit 777be5a

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

core/trainers/framework/runner.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,10 @@ def _executor_dataloader_train(self, model_dict, context):
209209

210210
if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
211211
"is_infer"] == False:
212-
if context["fleet_mode"]:
213-
if context["fleet_mode"].upper() == "PS":
214-
train_prog = context["model"][model_dict[
215-
"name"]]["main_program"]
216-
elif not context["is_fleet"] or context[
217-
"fleet_mode"].upper() == "COLLECTIVE":
212+
if context["fleet_mode"].upper() == "PS":
213+
train_prog = context["model"][model_dict["name"]][
214+
"main_program"]
215+
else:
218216
train_prog = context["model"][model_dict["name"]][
219217
"default_main_program"]
220218
startup_prog = context["model"][model_dict["name"]][

models/rank/dnn/config.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,13 @@ runner:
114114
print_interval: 1
115115
phases: [phase1]
116116

117-
- name: local_ps_train
118-
class: local_cluster_train
117+
- name: single_multi_gpu_train
118+
class: train
119119
# num of epochs
120120
epochs: 1
121121
# device to run training or infer
122-
device: cpu
123-
selected_gpus: "0" # select multiple GPU cards to run training
124-
work_num: 1
125-
server_num: 1
122+
device: gpu
123+
selected_gpus: "0,1" # select multiple GPU cards to run training
126124
save_checkpoint_interval: 1 # save model interval of epochs
127125
save_inference_interval: 4 # save inference
128126
save_step_interval: 1

0 commit comments

Comments
 (0)