Commit ebeb23c

some question
2 parents d837c5f + 777be5a commit ebeb23c

4 files changed: 90 additions, 8 deletions

core/trainers/framework/runner.py

Lines changed: 51 additions & 7 deletions
@@ -100,6 +100,7 @@ def _executor_dataset_train(self, model_dict, context):
         fetch_period = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".print_interval", 20))
+
         scope = context["model"][model_name]["scope"]
         program = context["model"][model_name]["main_program"]
         reader = context["dataset"][reader_name]
@@ -139,6 +140,9 @@ def _executor_dataloader_train(self, model_dict, context):
         fetch_period = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".print_interval", 20))
+        save_step_interval = int(
+            envs.get_global_env("runner." + context["runner_name"] +
+                                ".save_step_interval", -1))
         if context["is_infer"]:
             metrics = model_class.get_infer_results()
         else:
@@ -202,6 +206,24 @@ def _executor_dataloader_train(self, model_dict, context):
                         metrics_logging.insert(1, seconds)
                         begin_time = end_time
                         logging.info(metrics_format.format(*metrics_logging))
+
+                    if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
+                            "is_infer"] == False:
+                        if context["fleet_mode"].upper() == "PS":
+                            train_prog = context["model"][model_dict["name"]][
+                                "main_program"]
+                        else:
+                            train_prog = context["model"][model_dict["name"]][
+                                "default_main_program"]
+                        startup_prog = context["model"][model_dict["name"]][
+                            "startup_program"]
+                        with fluid.program_guard(train_prog, startup_prog):
+                            self.save(
+                                context,
+                                is_fleet=context["is_fleet"],
+                                epoch_id=None,
+                                batch_id=batch_id)
+
                     batch_id += 1
             except fluid.core.EOFException:
                 reader.reset()
@@ -314,7 +336,7 @@ def _get_ps_program(self, model_dict, context):
             exec_strategy=_exe_strategy)
         return program
 
-    def save(self, epoch_id, context, is_fleet=False):
+    def save(self, context, is_fleet=False, epoch_id=None, batch_id=None):
         def need_save(epoch_id, epoch_interval, is_last=False):
             name = "runner." + context["runner_name"] + "."
             total_epoch = int(envs.get_global_env(name + "epochs", 1))
@@ -371,7 +393,8 @@ def save_inference_model():
 
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))
-
+            logging.info("\tsave epoch_id:%d model into: \"%s\"" %
+                         (epoch_id, dirname))
             if is_fleet:
                 warnings.warn(
                     "Save inference model in cluster training is not recommended! Using save checkpoint instead.",
@@ -394,14 +417,35 @@ def save_persistables():
             if dirname is None or dirname == "":
                 return
             dirname = os.path.join(dirname, str(epoch_id))
+            logging.info("\tsave epoch_id:%d model into: \"%s\"" %
+                         (epoch_id, dirname))
+            if is_fleet:
+                if context["fleet"].worker_index() == 0:
+                    context["fleet"].save_persistables(context["exe"], dirname)
+            else:
+                fluid.io.save_persistables(context["exe"], dirname)
+
+        def save_checkpoint_step():
+            name = "runner." + context["runner_name"] + "."
+            save_interval = int(
+                envs.get_global_env(name + "save_step_interval", -1))
+            dirname = envs.get_global_env(name + "save_step_path", None)
+            if dirname is None or dirname == "":
+                return
+            dirname = os.path.join(dirname, str(batch_id))
+            logging.info("\tsave batch_id:%d model into: \"%s\"" %
+                         (batch_id, dirname))
             if is_fleet:
                 if context["fleet"].worker_index() == 0:
                     context["fleet"].save_persistables(context["exe"], dirname)
             else:
                 fluid.io.save_persistables(context["exe"], dirname)
 
-        save_persistables()
-        save_inference_model()
+        if isinstance(epoch_id, int):
+            save_persistables()
+            save_inference_model()
+        if isinstance(batch_id, int):
+            save_checkpoint_step()
 
 
 class SingleRunner(RunnerBase):
@@ -453,7 +497,7 @@ def run(self, context):
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
                 with fluid.program_guard(train_prog, startup_prog):
-                    self.save(epoch, context)
+                    self.save(context=context, epoch_id=epoch)
         context["status"] = "terminal_pass"
 
 
@@ -506,7 +550,7 @@ def run(self, context):
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
                 with fluid.program_guard(train_prog, startup_prog):
-                    self.save(epoch, context, True)
+                    self.save(context=context, is_fleet=True, epoch_id=epoch)
         context["status"] = "terminal_pass"
 
 
@@ -539,7 +583,7 @@ def run(self, context):
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
                 with fluid.program_guard(train_prog, startup_prog):
-                    self.save(epoch, context, True)
+                    self.save(context=context, is_fleet=True, epoch_id=epoch)
         context["status"] = "terminal_pass"
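
The reworked RunnerBase.save entry point now takes epoch_id and batch_id as optional keyword arguments and dispatches to the epoch-level savers or to the new save_checkpoint_step depending on which id is an int, while the dataloader loop triggers a step save every save_step_interval batches. Below is a minimal, self-contained sketch of that dispatch pattern; SketchRunner, train_loop, and the constructor arguments are illustrative stand-ins, not PaddleRec APIs, and the real method additionally wraps the savers in fluid.program_guard and writes actual checkpoints.

import os
import logging

logging.basicConfig(level=logging.INFO)


class SketchRunner(object):
    """Illustrative stand-in for RunnerBase; not the actual PaddleRec class."""

    def __init__(self, checkpoint_path, step_path, save_step_interval):
        self.checkpoint_path = checkpoint_path        # runner.save_checkpoint_path
        self.step_path = step_path                    # runner.save_step_path
        self.save_step_interval = save_step_interval  # runner.save_step_interval

    def save(self, is_fleet=False, epoch_id=None, batch_id=None):
        # Same dispatch idea as the patched RunnerBase.save: an int epoch_id
        # triggers the per-epoch savers, an int batch_id triggers the new
        # per-step checkpoint saver.
        if isinstance(epoch_id, int):
            dirname = os.path.join(self.checkpoint_path, str(epoch_id))
            logging.info("save epoch_id:%d model into: \"%s\"" % (epoch_id, dirname))
        if isinstance(batch_id, int):
            dirname = os.path.join(self.step_path, str(batch_id))
            logging.info("save batch_id:%d model into: \"%s\"" % (batch_id, dirname))

    def train_loop(self, num_batches):
        # Mirrors the dataloader loop: every save_step_interval batches the
        # runner saves by batch_id; the epoch-level save happens afterwards.
        for batch_id in range(1, num_batches + 1):
            if self.save_step_interval >= 1 and batch_id % self.save_step_interval == 0:
                self.save(batch_id=batch_id)
        self.save(epoch_id=0)


if __name__ == "__main__":
    SketchRunner("increment_dnn", "step_save", save_step_interval=2).train_loop(6)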

core/utils/envs.py

Lines changed: 20 additions & 1 deletion
@@ -20,7 +20,7 @@
 import sys
 import six
 import traceback
-import six
+import warnings
 
 global_envs = {}
 global_envs_flatten = {}
@@ -98,6 +98,25 @@ def fatten_env_namespace(namespace_nests, local_envs):
                 value = os_path_adapter(workspace_adapter(value))
                 global_envs[name] = value
 
+    for runner in envs["runner"]:
+        if "save_step_interval" in runner or "save_step_path" in runner:
+            phase_name = runner["phases"]
+            phase = [
+                phase for phase in envs["phase"]
+                if phase["name"] == phase_name[0]
+            ]
+            dataset_name = phase[0].get("dataset_name")
+            dataset = [
+                dataset for dataset in envs["dataset"]
+                if dataset["name"] == dataset_name
+            ]
+            if dataset[0].get("type") == "QueueDataset":
+                runner["save_step_interval"] = None
+                runner["save_step_path"] = None
+                warnings.warn(
+                    "QueueDataset does not support saving by step; please do not configure save_step_interval or save_step_path in your yaml"
+                )
+
     if get_platform() != "LINUX":
         for dataset in envs["dataset"]:
             name = ".".join(["dataset", dataset["name"], "type"])

doc/yaml.md

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,8 @@
 | init_model_path | string | path || path of the model used for initialization |
 | save_checkpoint_interval | int | >= 1 || epoch interval for saving checkpoint parameters |
 | save_checkpoint_path | string | path || save path of checkpoint parameters |
+| save_step_interval | int | >= 1 || batch interval for saving parameters by step |
+| save_step_path | string | path || save path of parameters saved by step |
 | save_inference_interval | int | >= 1 || epoch interval for saving the inference model |
 | save_inference_path | string | path || save path of the inference model |
 | save_inference_feed_varnames | list[string] | names of Variables in the network || input (feed) variable names of the inference model |

models/rank/dnn/config.yaml

Lines changed: 17 additions & 0 deletions
@@ -114,6 +114,23 @@ runner:
   print_interval: 1
   phases: [phase1]
 
+- name: single_multi_gpu_train
+  class: train
+  # num of epochs
+  epochs: 1
+  # device to run training or infer
+  device: gpu
+  selected_gpus: "0,1" # use multiple GPUs for training
+  save_checkpoint_interval: 1 # save model interval of epochs
+  save_inference_interval: 4 # save inference
+  save_step_interval: 1
+  save_checkpoint_path: "increment_dnn" # save checkpoint path
+  save_inference_path: "inference" # save inference path
+  save_step_path: "step_save"
+  save_inference_feed_varnames: [] # feed vars of save inference
+  save_inference_fetch_varnames: [] # fetch vars of save inference
+  print_interval: 1
+  phases: [phase1]
 # runner will run all the phase in each epoch
 phase:
 - name: phase1
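
As a usage note, the new single_multi_gpu_train runner saves a step checkpoint every batch (save_step_interval: 1) under step_save/<batch_id>, in addition to the per-epoch checkpoints. The following small sketch loads such a runner entry and checks its step-save keys; it assumes PyYAML is available and reads the parsed dict directly rather than going through PaddleRec's envs module, so it is illustrative only.

import yaml

RUNNER_YAML = """
runner:
- name: single_multi_gpu_train
  class: train
  device: gpu
  selected_gpus: "0,1"
  save_step_interval: 1
  save_step_path: "step_save"
  print_interval: 1
  phases: [phase1]
"""

conf = yaml.safe_load(RUNNER_YAML)
for runner in conf["runner"]:
    # A step interval below 1 (the -1 default in the trainer) disables step saving.
    interval = runner.get("save_step_interval", -1)
    path = runner.get("save_step_path", "")
    if interval >= 1 and path:
        print("runner %s saves a checkpoint every %d batch(es) under %s/<batch_id>"
              % (runner["name"], interval, path))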
