Skip to content

Commit 1961470

Browse files
authored
Add inference example and unit-test for word2vec chapter (#8206)
* Add unit-test and example * Fix type error * Fix unit test cases * Fix init error for cudaplace * Change unit-test options
1 parent 4b62fcd commit 1961470

File tree

4 files changed

+147
-12
lines changed

4 files changed

+147
-12
lines changed

paddle/inference/tests/book/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ function(inference_test TARGET_NAME)
2525
endfunction(inference_test)
2626

2727
inference_test(fit_a_line)
28-
inference_test(recognize_digits ARGS mlp)
2928
inference_test(image_classification ARGS vgg resnet)
3029
inference_test(label_semantic_roles)
31-
inference_test(rnn_encoder_decoder)
30+
inference_test(recognize_digits ARGS mlp)
3231
inference_test(recommender_system)
32+
inference_test(rnn_encoder_decoder)
3333
inference_test(understand_sentiment)
34+
inference_test(word2vec)

paddle/inference/tests/book/test_helper.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ template <typename Place, bool IsCombined = false>
9191
void TestInference(const std::string& dirname,
9292
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
9393
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
94-
// 1. Define place, executor, scope and inference_program
94+
// 1. Define place, executor, scope
9595
auto place = Place();
9696
auto executor = paddle::framework::Executor(place);
9797
auto* scope = new paddle::framework::Scope();
@@ -101,7 +101,8 @@ void TestInference(const std::string& dirname,
101101
if (IsCombined) {
102102
// All parameters are saved in a single file.
103103
// Hard-coding the file names of program and parameters in unittest.
104-
// Users are free to specify different filename.
104+
// Users are free to specify different filenames
105+
// (provided the filenames are also changed in the Python API: io.py)
105106
std::string prog_filename = "__model_combined__";
106107
std::string param_filename = "__params_combined__";
107108
inference_program = paddle::inference::Load(executor,
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "gflags/gflags.h"
17+
#include "test_helper.h"
18+
19+
DEFINE_string(dirname, "", "Directory of the inference model.");
20+
21+
// Runs the saved word2vec inference model through TestInference on CPU and,
// when built with PADDLE_WITH_CUDA, on a CUDA GPU as well, then passes both
// outputs to CheckError<float>. The model directory comes from --dirname.
TEST(inference, word2vec) {
22+
// --dirname is mandatory: log a fatal usage message when it is missing.
if (FLAGS_dirname.empty()) {
23+
LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
24+
}
25+
26+
LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
27+
std::string dirname = FLAGS_dirname;
28+
29+
// 0. Call `paddle::framework::InitDevices()` to initialize all the devices.
30+
//    In unit tests, this is done in paddle/testing/paddle_gtest_main.cc.
31+
32+
paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word;
33+
// A single sequence of length 1 for every input word.
paddle::framework::LoD lod{{0, 1}};
34+
// NOTE(review): presumably this must agree with the dictionary size the
// saved model was trained with -- confirm.
int64_t dict_size = 2072;  // Hard-coded size of the dictionary
35+
36+
// NOTE(review): SetupLoDTensor is not visible here (presumably declared in
// test_helper.h); the arguments look like (tensor, lod, lower bound, upper
// bound) for generated word ids -- confirm.
SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size);
37+
SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size);
38+
SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size);
39+
SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size);
40+
41+
// Feeds are passed by pointer, in the order first..fourth word.
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
42+
cpu_feeds.push_back(&first_word);
43+
cpu_feeds.push_back(&second_word);
44+
cpu_feeds.push_back(&third_word);
45+
cpu_feeds.push_back(&fourth_word);
46+
47+
// output1 receives the result of the CPU run.
paddle::framework::LoDTensor output1;
48+
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
49+
cpu_fetchs1.push_back(&output1);
50+
51+
// Run inference on CPU and log the LoD and dims of the result.
52+
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
53+
LOG(INFO) << output1.lod();
54+
LOG(INFO) << output1.dims();
55+
56+
#ifdef PADDLE_WITH_CUDA
57+
// output2 receives the result of the GPU run, fed with the same inputs.
paddle::framework::LoDTensor output2;
58+
std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
59+
cpu_fetchs2.push_back(&output2);
60+
61+
// Run inference on CUDA GPU and log the LoD and dims of the result.
62+
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
63+
LOG(INFO) << output2.lod();
64+
LOG(INFO) << output2.dims();
65+
66+
// Hand the CPU and GPU results to CheckError<float>.
// NOTE(review): presumably an element-wise tolerance comparison -- confirm.
CheckError<float>(output1, output2);
67+
#endif
68+
}

python/paddle/v2/fluid/tests/book/test_word2vec.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# Licensed under the Apache License, Version 2.0 (the "License");
43
# you may not use this file except in compliance with the License.
54
# You may obtain a copy of the License at
65
#
@@ -16,14 +15,67 @@
1615
import paddle.v2.fluid as fluid
1716
import unittest
1817
import os
18+
import numpy as np
1919
import math
2020
import sys
2121

2222

23-
def main(use_cuda, is_sparse, parallel):
24-
if use_cuda and not fluid.core.is_compiled_with_cuda():
23+
# Build a fluid.LoDTensor of random int64 values with shape [lod[-1], 1],
# placed on `place`, whose LoD is the single level `lod`.
#   lod:   one LoD level (list of offsets); its last entry sets the row count.
#   place: device place handed to LoDTensor.set().
#   low, high: bounds forwarded to np.random.random_integers; per the NumPy
#          docs that function samples the closed interval [low, high].
# NOTE(review): np.random.random_integers is deprecated in newer NumPy
# releases -- confirm the NumPy version this test is expected to run with.
def create_random_lodtensor(lod, place, low, high):
24+
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
25+
res = fluid.LoDTensor()
26+
res.set(data, place)
27+
# set_lod takes a list of levels, so the single level is wrapped in a list.
res.set_lod([lod])
28+
return res
29+
30+
31+
# Load the inference model stored in `save_dirname`, feed it four random
# single-word inputs and print the LoD, shape and values of the first result.
#   use_cuda:     run on fluid.CUDAPlace(0) when true, else fluid.CPUPlace().
#   save_dirname: directory of the saved inference model; when None the
#                 function returns immediately without doing anything.
def infer(use_cuda, save_dirname=None):
32+
if save_dirname is None:
2533
return
2634

35+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
36+
exe = fluid.Executor(place)
37+
38+
# Use fluid.io.load_inference_model to obtain the inference program desc,
39+
# the feed_target_names (the names of variables that will be fed
40+
# data using feed operators), and the fetch_targets (variables that
41+
# we want to obtain data from using fetch operators).
42+
[inference_program, feed_target_names,
43+
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
44+
45+
word_dict = paddle.dataset.imikolov.build_dict()
46+
# Largest index into the dictionary; passed below as the `high` bound of
# the random word ids.
dict_size = len(word_dict) - 1
47+
48+
# Set up the input by creating 4 words, each with the LoD required by
49+
# lookup_table_op (a single sequence of length 1).
50+
lod = [0, 1]
51+
first_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
52+
second_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
53+
third_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
54+
fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size)
55+
56+
# The loaded feed names must come back in this exact order.
assert feed_target_names[0] == 'firstw'
57+
assert feed_target_names[1] == 'secondw'
58+
assert feed_target_names[2] == 'thirdw'
59+
# 'forthw' (sic) is the spelling used by the data layer and by
# save_inference_model in train(), so it must stay as is.
assert feed_target_names[3] == 'forthw'
60+
61+
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
62+
# and results will contain a list of data corresponding to fetch_targets.
63+
results = exe.run(inference_program,
64+
feed={
65+
feed_target_names[0]: first_word,
66+
feed_target_names[1]: second_word,
67+
feed_target_names[2]: third_word,
68+
feed_target_names[3]: fourth_word
69+
},
70+
fetch_list=fetch_targets,
71+
return_numpy=False)
72+
# return_numpy=False keeps the result as a tensor object, so its LoD can be
# printed before it is converted to a numpy array.
print(results[0].lod())
73+
np_data = np.array(results[0])
74+
print("Inference Shape: ", np_data.shape)
75+
print("Inference results: ", np_data)
76+
77+
78+
def train(use_cuda, is_sparse, parallel, save_dirname):
2779
PASS_NUM = 100
2880
EMBED_SIZE = 32
2981
HIDDEN_SIZE = 256
@@ -67,7 +119,7 @@ def __network__(words):
67119
act='softmax')
68120
cost = fluid.layers.cross_entropy(input=predict_word, label=words[4])
69121
avg_cost = fluid.layers.mean(x=cost)
70-
return avg_cost
122+
return avg_cost, predict_word
71123

72124
word_dict = paddle.dataset.imikolov.build_dict()
73125
dict_size = len(word_dict)
@@ -79,13 +131,13 @@ def __network__(words):
79131
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
80132

81133
if not parallel:
82-
avg_cost = __network__(
134+
avg_cost, predict_word = __network__(
83135
[first_word, second_word, third_word, forth_word, next_word])
84136
else:
85137
places = fluid.layers.get_places()
86138
pd = fluid.layers.ParallelDo(places)
87139
with pd.do():
88-
avg_cost = __network__(
140+
avg_cost, predict_word = __network__(
89141
map(pd.read_input, [
90142
first_word, second_word, third_word, forth_word, next_word
91143
]))
@@ -113,13 +165,25 @@ def __network__(words):
113165
feed=feeder.feed(data),
114166
fetch_list=[avg_cost])
115167
if avg_cost_np[0] < 5.0:
168+
if save_dirname is not None:
169+
fluid.io.save_inference_model(save_dirname, [
170+
'firstw', 'secondw', 'thirdw', 'forthw'
171+
], [predict_word], exe)
116172
return
117173
if math.isnan(float(avg_cost_np[0])):
118174
sys.exit("got NaN loss, training failed.")
119175

120176
raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0]))
121177

122178

179+
# Run one test configuration end to end: train the word2vec model (train()
# saves an inference model into `save_dirname` once the cost drops below its
# threshold), then load that model and run infer() on it.
#   use_cuda:  run on GPU; the whole case is silently skipped when fluid was
#              not compiled with CUDA.
#   is_sparse, parallel: forwarded unchanged to train().
def main(use_cuda, is_sparse, parallel):
180+
if use_cuda and not fluid.core.is_compiled_with_cuda():
181+
return
182+
# Relative directory shared by train() (writer) and infer() (reader).
save_dirname = "word2vec.inference.model"
183+
train(use_cuda, is_sparse, parallel, save_dirname)
184+
infer(use_cuda, save_dirname)
185+
186+
123187
FULL_TEST = os.getenv('FULL_TEST',
124188
'0').lower() in ['true', '1', 't', 'y', 'yes', 'on']
125189
SKIP_REASON = "Only run minimum number of tests in CI server, to make CI faster"
@@ -142,7 +206,8 @@ def __impl__(*args, **kwargs):
142206
with fluid.program_guard(prog, startup_prog):
143207
main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel)
144208

145-
if use_cuda and is_sparse and parallel:
209+
# run only 2 cases: use_cuda is either True or False
210+
if is_sparse == False and parallel == False:
146211
fn = __impl__
147212
else:
148213
# skip the other test when on CI server

0 commit comments

Comments
 (0)