
Commit 9f8abe9

Fix several bugs in the Save and Infer stages (#95)
* fix setup
* fix bug for dssm reader
* fix net bug at PY3 for afm
* fix multi cards with files
* fix ctr
* add validation
* add validation
* add validation
* fix compile
* fix ci
* fix user define runner
* fix gnn reader at PY3
* fix fast yaml config at PY3

Co-authored-by: tangwei <[email protected]>
1 parent: c1af414

13 files changed: +257, -113 lines

core/trainers/framework/network.py

Lines changed: 5 additions & 0 deletions
@@ -94,6 +94,7 @@ def build_network(self, context):
             context["model"][model_dict["name"]]["model"] = model
             context["model"][model_dict["name"]][
                 "default_main_program"] = train_program.clone()
+            context["model"][model_dict["name"]]["compiled_program"] = None
 
         context["dataset"] = {}
         for dataset in context["env"]["dataset"]:
@@ -149,6 +150,7 @@ def build_network(self, context):
             context["model"][model_dict["name"]]["model"] = model
             context["model"][model_dict["name"]]["default_main_program"] = context[
                 "fleet"].main_program.clone()
+            context["model"][model_dict["name"]]["compiled_program"] = None
 
         if context["fleet"].is_server():
             self._server(context)
@@ -245,6 +247,8 @@ def build_network(self, context):
             context["model"][model_dict["name"]]["model"] = model
             context["model"][model_dict["name"]][
                 "default_main_program"] = train_program.clone()
+            context["model"][model_dict["name"]][
+                "compile_program"] = None
 
         if context["fleet"].is_server():
             self._server(context)
@@ -314,6 +318,7 @@ def build_network(self, context):
             context["model"][model_dict["name"]]["model"] = model
             context["model"][model_dict["name"]][
                 "default_main_program"] = train_program
+            context["model"][model_dict["name"]]["compiled_program"] = None
 
         context["dataset"] = {}
         for dataset in context["env"]["dataset"]:
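Each phase entry now reserves a compiled_program slot, initialized to None at network-build time, so the runner can compile the program once and reuse it on every later pass instead of recompiling per epoch (the fix behind "fix multi cards with files" and "fix compile"). A minimal sketch of that convention; the helper names build_model_entry and get_compiled_program are illustrative and not part of the repository:

def build_model_entry(model, train_program):
    """Assemble the per-model entry the trainer keeps in its context dict."""
    return {
        "model": model,
        "default_main_program": train_program,  # clone() in the real code
        "compiled_program": None,               # filled lazily by the runner
    }

def get_compiled_program(entry, compile_fn):
    """Compile on first use, then return the cached program on later calls."""
    if entry["compiled_program"] is None:
        entry["compiled_program"] = compile_fn(entry["default_main_program"])
    return entry["compiled_program"]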

core/trainers/framework/runner.py

Lines changed: 76 additions & 14 deletions
@@ -50,6 +50,7 @@ def _executor_dataset_train(self, model_dict, context):
         reader_name = model_dict["dataset_name"]
         model_name = model_dict["name"]
         model_class = context["model"][model_dict["name"]]["model"]
+
         fetch_vars = []
         fetch_alias = []
         fetch_period = int(
@@ -89,19 +90,7 @@ def _executor_dataset_train(self, model_dict, context):
     def _executor_dataloader_train(self, model_dict, context):
         model_name = model_dict["name"]
         model_class = context["model"][model_dict["name"]]["model"]
-
-        if context["is_infer"]:
-            program = context["model"][model_name]["main_program"]
-        elif context["is_fleet"]:
-            if context["fleet_mode"].upper() == "PS":
-                program = self._get_ps_program(model_dict, context)
-            elif context["fleet_mode"].upper() == "COLLECTIVE":
-                program = context["model"][model_name]["main_program"]
-        elif not context["is_fleet"]:
-            if context["device"].upper() == "CPU":
-                program = self._get_single_cpu_program(model_dict, context)
-            elif context["device"].upper() == "GPU":
-                program = self._get_single_gpu_program(model_dict, context)
+        program = self._get_dataloader_program(model_dict, context)
 
         reader_name = model_dict["dataset_name"]
         fetch_vars = []
@@ -143,6 +132,24 @@ def _executor_dataloader_train(self, model_dict, context):
         except fluid.core.EOFException:
             reader.reset()
 
+    def _get_dataloader_program(self, model_dict, context):
+        model_name = model_dict["name"]
+        if context["model"][model_name]["compiled_program"] == None:
+            if context["is_infer"]:
+                program = context["model"][model_name]["main_program"]
+            elif context["is_fleet"]:
+                if context["fleet_mode"].upper() == "PS":
+                    program = self._get_ps_program(model_dict, context)
+                elif context["fleet_mode"].upper() == "COLLECTIVE":
+                    program = context["model"][model_name]["main_program"]
+            elif not context["is_fleet"]:
+                if context["device"].upper() == "CPU":
+                    program = self._get_single_cpu_program(model_dict, context)
+                elif context["device"].upper() == "GPU":
+                    program = self._get_single_gpu_program(model_dict, context)
+            context["model"][model_name]["compiled_program"] = program
+        return context["model"][model_name]["compiled_program"]
+
     def _get_strategy(self, model_dict, context):
         _build_strategy = fluid.BuildStrategy()
         _exe_strategy = fluid.ExecutionStrategy()
@@ -218,12 +225,17 @@ def _get_ps_program(self, model_dict, context):
 
     def save(self, epoch_id, context, is_fleet=False):
         def need_save(epoch_id, epoch_interval, is_last=False):
+            name = "runner." + context["runner_name"] + "."
+            total_epoch = int(envs.get_global_env(name + "epochs", 1))
+            if epoch_id + 1 == total_epoch:
+                is_last = True
+
             if is_last:
                 return True
             if epoch_id == -1:
                 return False
 
-            return epoch_id % epoch_interval == 0
+            return (epoch_id + 1) % epoch_interval == 0
 
         def save_inference_model():
             name = "runner." + context["runner_name"] + "."
@@ -415,3 +427,53 @@ def run(self, context):
 
         """
        context["status"] = "terminal_pass"
+
+
+class SingleInferRunner(RunnerBase):
+    def __init__(self, context):
+        print("Running SingleInferRunner.")
+        pass
+
+    def run(self, context):
+        self._dir_check(context)
+
+        for index, epoch_name in enumerate(self.epoch_model_name_list):
+            for model_dict in context["phases"]:
+                self._load(context, model_dict,
+                           self.epoch_model_path_list[index])
+                begin_time = time.time()
+                self._run(context, model_dict)
+                end_time = time.time()
+                seconds = end_time - begin_time
+                print("Infer {} of {} done, use time: {}".format(model_dict[
+                    "name"], epoch_name, seconds))
+        context["status"] = "terminal_pass"
+
+    def _load(self, context, model_dict, model_path):
+        if model_path is None or model_path == "":
+            return
+        print("load persistables from", model_path)
+
+        with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
+            train_prog = context["model"][model_dict["name"]]["main_program"]
+            startup_prog = context["model"][model_dict["name"]][
+                "startup_program"]
+            with fluid.program_guard(train_prog, startup_prog):
+                fluid.io.load_persistables(
+                    context["exe"], model_path, main_program=train_prog)
+
+    def _dir_check(self, context):
+        dirname = envs.get_global_env(
+            "runner." + context["runner_name"] + ".init_model_path", None)
+        self.epoch_model_path_list = []
+        self.epoch_model_name_list = []
+
+        for file in os.listdir(dirname):
+            file_path = os.path.join(dirname, file)
+            if os.path.isdir(file_path):
+                self.epoch_model_path_list.append(file_path)
+                self.epoch_model_name_list.append(file)
+
+        if len(self.epoch_model_path_list) == 0:
+            self.epoch_model_path_list.append(dirname)
+            self.epoch_model_name_list.append(dirname)
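Two behaviour changes stand out here: _get_dataloader_program now caches the compiled program in the slot reserved by network.py, and need_save reads the total epoch count from the runner config so the final epoch is always saved and the interval is counted against one-based epoch numbers. A standalone sketch of the corrected save schedule; total_epoch is a plain argument here, whereas the real runner reads it through envs.get_global_env:

def need_save(epoch_id, epoch_interval, total_epoch):
    """epoch_id counts from 0; save every epoch_interval epochs and on the last one."""
    if epoch_id + 1 == total_epoch:
        return True                      # always keep the final epoch
    if epoch_id == -1:
        return False
    return (epoch_id + 1) % epoch_interval == 0

# With 5 epochs and interval 2, epochs 1, 3 and 4 are saved (0-based ids);
# the old `epoch_id % epoch_interval` variant saved epoch 0 instead.
print([e for e in range(5) if need_save(e, 2, 5)])  # [1, 3, 4]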

core/trainers/framework/startup.py

Lines changed: 18 additions & 0 deletions
@@ -101,3 +101,21 @@ def startup(self, context):
                     context["exe"].run(startup_prog)
                 self.load(context, True)
         context["status"] = "train_pass"
+
+
+class SingleInferStartup(StartupBase):
+    def __init__(self, context):
+        print("Running SingleInferStartup.")
+        pass
+
+    def startup(self, context):
+        for model_dict in context["phases"]:
+            with fluid.scope_guard(context["model"][model_dict["name"]][
+                    "scope"]):
+                train_prog = context["model"][model_dict["name"]][
+                    "main_program"]
+                startup_prog = context["model"][model_dict["name"]][
+                    "startup_program"]
+                with fluid.program_guard(train_prog, startup_prog):
+                    context["exe"].run(startup_prog)
+        context["status"] = "train_pass"

core/trainers/general_trainer.py

Lines changed: 7 additions & 3 deletions
@@ -101,7 +101,9 @@ def startup(self, context):
             startup_class = envs.lazy_instance_by_fliename(startup_class_path,
                                                            "Startup")(context)
         else:
-            if self.engine == EngineMode.SINGLE:
+            if self.engine == EngineMode.SINGLE and context["is_infer"]:
+                startup_class_name = "SingleInferStartup"
+            elif self.engine == EngineMode.SINGLE and not context["is_infer"]:
                 startup_class_name = "SingleStartup"
             elif self.fleet_mode == FleetMode.PS or self.fleet_mode == FleetMode.PSLIB:
                 startup_class_name = "PSStartup"
@@ -117,12 +119,14 @@ def startup(self, context):
 
     def runner(self, context):
         runner_class_path = envs.get_global_env(
-            self.runner_env_name + ".runner_class_paht", default_value=None)
+            self.runner_env_name + ".runner_class_path", default_value=None)
         if runner_class_path:
             runner_class = envs.lazy_instance_by_fliename(runner_class_path,
                                                           "Runner")(context)
         else:
-            if self.engine == EngineMode.SINGLE:
+            if self.engine == EngineMode.SINGLE and context["is_infer"]:
+                runner_class_name = "SingleInferRunner"
+            elif self.engine == EngineMode.SINGLE and not context["is_infer"]:
                 runner_class_name = "SingleRunner"
             elif self.fleet_mode == FleetMode.PSLIB:
                 runner_class_name = "PslibRunner"
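With these branches, a single-card job dispatches to the new infer classes whenever is_infer is set; the second hunk also fixes the runner_class_paht typo, so a user-defined runner_class_path in the YAML is actually honoured ("fix user define runner"). A compact, hedged restatement of the selection logic; only the branches visible in this diff are covered, and plain strings stand in for the EngineMode/FleetMode enums:

def pick_runner_class(engine, fleet_mode, is_infer):
    # Mirrors the fixed branch order in general_trainer.runner().
    if engine == "SINGLE" and is_infer:
        return "SingleInferRunner"
    if engine == "SINGLE":
        return "SingleRunner"
    if fleet_mode == "PSLIB":
        return "PslibRunner"
    raise ValueError("combination not shown in this diff")

assert pick_runner_class("SINGLE", None, is_infer=True) == "SingleInferRunner"
assert pick_runner_class("SINGLE", None, is_infer=False) == "SingleRunner"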

core/utils/validation.py

Lines changed: 54 additions & 29 deletions
@@ -16,38 +16,47 @@
 
 
 class ValueFormat:
-    def __init__(self, value_type, value, value_handler):
+    def __init__(self, value_type, value, value_handler, required=False):
         self.value_type = value_type
-        self.value = value
         self.value_handler = value_handler
+        self.value = value
+        self.required = required
 
     def is_valid(self, name, value):
-        ret = self.is_type_valid(name, value)
+
+        if not self.value_type:
+            ret = True
+        else:
+            ret = self.is_type_valid(name, value)
+
         if not ret:
             return ret
 
+        if not self.value or not self.value_handler:
+            return True
+
         ret = self.is_value_valid(name, value)
         return ret
 
     def is_type_valid(self, name, value):
        if self.value_type == "int":
            if not isinstance(value, int):
                print("\nattr {} should be int, but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                return False
            return True
 
        elif self.value_type == "str":
            if not isinstance(value, str):
                print("\nattr {} should be str, but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                return False
            return True
 
        elif self.value_type == "strs":
            if not isinstance(value, list):
                print("\nattr {} should be list(str), but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                return False
            for v in value:
                if not isinstance(v, str):
@@ -56,10 +65,29 @@ def is_type_valid(self, name, value):
                     return False
             return True
 
+        elif self.value_type == "dict":
+            if not isinstance(value, dict):
+                print("\nattr {} should be str, but {} now\n".format(
+                    name, type(value)))
+                return False
+            return True
+
+        elif self.value_type == "dicts":
+            if not isinstance(value, list):
+                print("\nattr {} should be list(dist), but {} now\n".format(
+                    name, type(value)))
+                return False
+            for v in value:
+                if not isinstance(v, dict):
+                    print("\nattr {} should be list(dist), but list({}) now\n".
+                          format(name, type(v)))
+                    return False
+            return True
+
         elif self.value_type == "ints":
             if not isinstance(value, list):
                 print("\nattr {} should be list(int), but {} now\n".format(
-                    name, self.value_type))
+                    name, type(value)))
                 return False
             for v in value:
                 if not isinstance(v, int):
@@ -74,7 +102,7 @@ def is_type_valid(self, name, value):
             return False
 
     def is_value_valid(self, name, value):
-        ret = self.value_handler(value)
+        ret = self.value_handler(name, value, self.value)
         return ret
 
 
@@ -112,38 +140,35 @@ def le_value_handler(name, value, values):
 
 def register():
     validations = {}
-    validations["train.workspace"] = ValueFormat("str", None, eq_value_handler)
-    validations["train.device"] = ValueFormat("str", ["cpu", "gpu"],
-                                              in_value_handler)
-    validations["train.epochs"] = ValueFormat("int", 1, ge_value_handler)
-    validations["train.engine"] = ValueFormat(
-        "str", ["train", "infer", "local_cluster_train", "cluster_train"],
-        in_value_handler)
-
-    requires = ["workspace", "dataset", "mode", "runner", "phase"]
-    return validations, requires
+    validations["workspace"] = ValueFormat("str", None, None, True)
+    validations["mode"] = ValueFormat(None, None, None, True)
+    validations["runner"] = ValueFormat("dicts", None, None, True)
+    validations["phase"] = ValueFormat("dicts", None, None, True)
+    validations["hyper_parameters"] = ValueFormat("dict", None, None, False)
+    return validations
 
 
 def yaml_validation(config):
-    all_checkers, require_checkers = register()
+    all_checkers = register()
+
+    require_checkers = []
+    for name, checker in all_checkers.items():
+        if checker.required:
+            require_checkers.append(name)
 
     _config = envs.load_yaml(config)
-    flattens = envs.flatten_environs(_config)
 
     for required in require_checkers:
-        if required not in flattens.keys():
+        if required not in _config.keys():
             print("\ncan not find {} in yaml, which is required\n".format(
                 required))
             return False
 
-    for name, flatten in flattens.items():
+    for name, value in _config.items():
         checker = all_checkers.get(name, None)
-
-        if not checker:
-            continue
-
-        ret = checker.is_valid(name, flattens)
-        if not ret:
-            return False
+        if checker:
+            ret = checker.is_valid(name, value)
+            if not ret:
+                return False
 
     return True
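The validator now works on the raw YAML dict rather than on flattened keys: register() marks workspace, mode, runner and phase as required, runner and phase must be lists of dicts, hyper_parameters a dict, and any key without a checker is simply skipped. An illustrative walk-through of the same idea with a plain dict standing in for envs.load_yaml(config); the key names follow register(), while the sample values are made up:

REQUIRED = {"workspace", "mode", "runner", "phase"}
TYPES = {"workspace": str, "runner": list, "phase": list, "hyper_parameters": dict}

def validate(config):
    missing = REQUIRED - config.keys()
    if missing:
        print("can not find {} in yaml, which is required".format(sorted(missing)))
        return False
    for name, expected in TYPES.items():
        if name in config and not isinstance(config[name], expected):
            print("attr {} should be {}, but {} now".format(
                name, expected.__name__, type(config[name])))
            return False
    return True

config = {
    "workspace": "models/rank/afm",
    "mode": "single_cpu_train",
    "runner": [{"name": "single_cpu_train", "class": "train"}],
    "phase": [{"name": "phase1", "model": "{workspace}/model.py"}],
}
print(validate(config))  # True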

models/match/dssm/synthetic_evaluate_reader.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ def reader():
             This function needs to be implemented by the user, based on data format
             """
             features = line.rstrip('\n').split('\t')
-            query = map(float, features[0].split(','))
-            pos_doc = map(float, features[1].split(','))
+            query = [float(feature) for feature in features[0].split(',')]
+            pos_doc = [float(feature) for feature in features[1].split(',')]
             feature_names = ['query', 'doc_pos']
 
             yield zip(feature_names, [query] + [pos_doc])
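This change, and the identical one in synthetic_reader.py just below, addresses a Python 3 behaviour: map() returns a lazy, single-use iterator rather than a list, so the values vanish once the reader pipeline consumes them. A small demonstration of the difference on a made-up input line:

line = "0.1,0.2,0.3\t0.4,0.5,0.6"
features = line.rstrip('\n').split('\t')

query_lazy = map(float, features[0].split(','))           # Python 3: a one-shot iterator
query_list = [float(f) for f in features[0].split(',')]   # the fix: a real list

print(list(query_lazy))             # [0.1, 0.2, 0.3]  (first and only read)
print(list(query_lazy))             # []               (the iterator is exhausted)
print(query_list, len(query_list))  # [0.1, 0.2, 0.3] 3 (safe to reuse)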

models/match/dssm/synthetic_reader.py

Lines changed: 5 additions & 3 deletions
@@ -31,13 +31,15 @@ def reader():
             This function needs to be implemented by the user, based on data format
             """
             features = line.rstrip('\n').split('\t')
-            query = map(float, features[0].split(','))
-            pos_doc = map(float, features[1].split(','))
+            query = [float(feature) for feature in features[0].split(',')]
+            pos_doc = [float(feature) for feature in features[1].split(',')]
             feature_names = ['query', 'doc_pos']
             neg_docs = []
             for i in range(len(features) - 2):
                 feature_names.append('doc_neg_' + str(i))
-                neg_docs.append(map(float, features[i + 2].split(',')))
+                neg_docs.append([
+                    float(feature) for feature in features[i + 2].split(',')
+                ])
 
             yield zip(feature_names, [query] + [pos_doc] + neg_docs)
models/rank/afm/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def net(self, inputs, is_infer=False):
133133
attention_h) # (batch_size * (num_field*(num_field-1)/2)) * 1
134134
attention_out = fluid.layers.softmax(
135135
attention_out) # (batch_size * (num_field*(num_field-1)/2)) * 1
136-
num_interactions = self.num_field * (self.num_field - 1) / 2
136+
num_interactions = int(self.num_field * (self.num_field - 1) / 2)
137137
attention_out = fluid.layers.reshape(
138138
attention_out,
139139
shape=[-1, num_interactions,
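The one-line change works around Python 3 true division: num_field * (num_field - 1) / 2 evaluates to a float, and a float is not a valid entry in a reshape target shape. A quick arithmetic check of the two variants:

num_field = 10
num_interactions_py3 = num_field * (num_field - 1) / 2       # 45.0, a float under Python 3
num_interactions_fix = int(num_field * (num_field - 1) / 2)  # 45, usable as a shape entry
assert num_interactions_fix == num_field * (num_field - 1) // 2
print(type(num_interactions_py3), type(num_interactions_fix))  # <class 'float'> <class 'int'>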
