Skip to content

Commit ae19907

Browse files
authored
Test word2vec (#10779)
1 parent 11b6473 commit ae19907

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

paddle/fluid/inference/tests/test_helper.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,7 @@ void TestInference(const std::string& dirname,
236236

237237
// Disable the profiler and print the timing information
238238
paddle::platform::DisableProfiler(
239-
paddle::platform::EventSortingKey::kDefault,
240-
"run_inference_profiler");
239+
paddle::platform::EventSortingKey::kDefault, "run_inference_profiler");
241240
paddle::platform::ResetProfiler();
242241
}
243242

python/paddle/fluid/tests/book/high-level-api/word2vec/no_test_word2vec_new_api.py renamed to python/paddle/fluid/tests/book/high-level-api/word2vec/test_word2vec_new_api.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def train_program(is_sparse):
9090
return avg_cost
9191

9292

93-
def train(use_cuda, train_program, save_path):
93+
def train(use_cuda, train_program, save_dirname):
9494
train_reader = paddle.batch(
9595
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
9696
test_reader = paddle.batch(
@@ -99,50 +99,69 @@ def train(use_cuda, train_program, save_path):
9999
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
100100

101101
def event_handler(event):
102-
if isinstance(event, fluid.EndEpochEvent):
103-
outs = trainer.test(reader=test_reader)
102+
if isinstance(event, fluid.EndStepEvent):
103+
outs = trainer.test(
104+
reader=test_reader,
105+
feed_order=['firstw', 'secondw', 'thirdw', 'forthw', 'nextw'])
104106
avg_cost = outs[0]
105107
print("loss= ", avg_cost)
106108

107-
if avg_cost < 5.0:
108-
trainer.save_params(save_path)
109-
return
109+
if avg_cost < 10.0:
110+
trainer.save_params(save_dirname)
111+
trainer.stop()
112+
110113
if math.isnan(avg_cost):
111114
sys.exit("got NaN loss, training failed.")
112115

113116
trainer = fluid.Trainer(
114-
train_program, fluid.optimizer.SGD(learning_rate=0.001), place=place)
117+
train_func=train_program,
118+
optimizer=fluid.optimizer.SGD(learning_rate=0.001),
119+
place=place)
120+
115121
trainer.train(
116-
reader=train_reader, num_epochs=1, event_handler=event_handler)
122+
reader=train_reader,
123+
num_epochs=1,
124+
event_handler=event_handler,
125+
feed_order=['firstw', 'secondw', 'thirdw', 'forthw', 'nextw'])
117126

118127

119-
def infer(use_cuda, inference_program, save_path):
128+
def infer(use_cuda, inference_program, save_dirname=None):
120129
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
121130
inferencer = fluid.Inferencer(
122-
infer_func=inference_program, param_path=save_path, place=place)
131+
infer_func=inference_program, param_path=save_dirname, place=place)
123132

124133
lod = [0, 1]
125134
first_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
126135
second_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
127136
third_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
128137
fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
129138

130-
result = inferencer.infer({
131-
'firstw': first_word,
132-
'secondw': second_word,
133-
'thirdw': third_word,
134-
'forthw': fourth_word
135-
})
139+
result = inferencer.infer(
140+
{
141+
'firstw': first_word,
142+
'secondw': second_word,
143+
'thirdw': third_word,
144+
'forthw': fourth_word
145+
},
146+
return_numpy=False)
136147
print(np.array(result[0]))
137148

138149

139150
def main(use_cuda, is_sparse):
140151
if use_cuda and not fluid.core.is_compiled_with_cuda():
141152
return
142153

143-
save_path = "word2vec.params"
144-
train(use_cuda, partial(train_program, is_sparse), save_path)
145-
infer(use_cuda, partial(inference_program, is_sparse), save_path)
154+
save_path = "word2vec.inference.model"
155+
156+
train(
157+
use_cuda=use_cuda,
158+
train_program=partial(train_program, is_sparse),
159+
save_dirname=save_path)
160+
161+
infer(
162+
use_cuda=use_cuda,
163+
inference_program=partial(inference_program, is_sparse),
164+
save_dirname=save_path)
146165

147166

148167
if __name__ == '__main__':

0 commit comments

Comments
 (0)