
Commit af9a330

test=develop

1 parent 014e50c

File tree: 4 files changed, +152 −130 lines changed

paddle/fluid/framework/selected_rows.h

Lines changed: 4 additions & 2 deletions

@@ -121,7 +121,9 @@ class SelectedRows {
   int64_t AutoGrownIndex(int64_t key, bool auto_grown);
 
   void SyncIndex();
-
+  /*
+   * @brief Get the complete dims: the value's dims with dims[0] replaced by height_.
+   */
   DDim GetCompleteDims() const {
     std::vector<int64_t> dims = vectorize(value_->dims());
     dims[0] = height_;
@@ -136,7 +138,7 @@ class SelectedRows {
   std::unordered_map<int64_t, int64_t>
       id_to_index_;  // should not be used when ids has duplicate member
   std::unique_ptr<Tensor> value_{nullptr};
-  int64_t height_;
+  int64_t height_;  // height indicates the underlying tensor's height
   std::unique_ptr<RWLock> rwlock_{nullptr};
 };
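Note on the change above: SelectedRows stores only the rows that are actually present (in value_), while height_ records the first dimension of the conceptually dense tensor, which is why GetCompleteDims() swaps dims[0] for height_. A minimal numpy sketch of that semantics (illustrative names only, not Paddle's API):

import numpy as np

class SelectedRowsSketch:
    """Hypothetical stand-in for framework::SelectedRows."""

    def __init__(self, height, rows, value):
        self.height = height  # first dim of the conceptually dense tensor
        self.rows = rows      # which dense rows `value` actually holds
        self.value = value    # shape: [len(rows), width]

    def get_complete_dims(self):
        # Mirrors GetCompleteDims(): the value's dims, with dims[0]
        # replaced by height rather than the number of stored rows.
        return (self.height,) + self.value.shape[1:]

sr = SelectedRowsSketch(height=4, rows=[0, 2], value=np.ones((2, 3)))
assert sr.get_complete_dims() == (4, 3)  # not (2, 3)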

paddle/fluid/operators/hierarchical_sigmoid_op.cc

Lines changed: 3 additions & 2 deletions

@@ -145,8 +145,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(PreOut) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
-                   "Output(W@Grad should not be null.)");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
+                   "Output(W@Grad) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@Grad) should not be null.");
     if (ctx->HasOutput(framework::GradVarName("Bias"))) {
       ctx->SetOutputDim(framework::GradVarName("Bias"),
                         ctx->GetInputDim("Bias"));

paddle/fluid/operators/hierarchical_sigmoid_op.h

Lines changed: 1 addition & 1 deletion

@@ -191,10 +191,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     framework::Vector<int64_t> real_rows = cal_rows(path);
     auto* w_grad =
         ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
-
     w_grad->set_rows(real_rows);
     // build ids -> rows index map
     w_grad->SyncIndex();
+    w_grad->set_height(w->dims()[0]);
     auto* w_grad_value = w_grad->mutable_value();
     framework::DDim temp_dim(w->dims());
     set(temp_dim, 0, real_rows.size());
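Why set_height(w->dims()[0]) matters: w_grad only materializes the rows touched by the sampled tree paths (real_rows), but a downstream optimizer scatters those rows back into the full weight matrix, so the SelectedRows' height has to be W's first dimension rather than len(real_rows). A hedged numpy sketch of that scatter-style update (illustrative only, not Paddle's optimizer code):

import numpy as np

def apply_selected_rows_grad(w, rows, grad_value, lr=1e-3):
    # w:          dense weight, shape [height, width]
    # rows:       real_rows -- indices of the rows the gradient touches
    # grad_value: shape [len(rows), width], the SelectedRows' value tensor
    out = w.copy()
    for r, g in zip(rows, grad_value):
        out[r] -= lr * g  # only the touched rows change
    return out

w = np.zeros((4, 3))          # height must stay 4 == w.shape[0]
rows = [1, 3]                 # e.g. nodes visited on the sampled paths
grad_value = np.ones((2, 3))
w_new = apply_selected_rows_grad(w, rows, grad_value)
assert w_new.shape == w.shape  # the gradient indexes into the full weight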

python/paddle/fluid/tests/unittests/test_hsigmoid_op.py

Lines changed: 144 additions & 125 deletions

@@ -140,148 +140,167 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
     return pre_output, out
 
 
-# class TestHSigmoidOp(OpTest):
-#     def setUp(self):
-#         self.op_type = "hierarchical_sigmoid"
-#         num_classes = 6
-#         feature_size = 8
-#         batch_size = 4
-#         x = np.random.random((batch_size, feature_size)).astype("float32") * 2
-#         w = np.random.random(
-#             (num_classes - 1, feature_size)).astype("float32") * 2
-#         label = np.random.randint(0, num_classes, (batch_size, 1))
-#         bias = np.random.random((1, num_classes - 1)).astype("float32")
-#         self.attrs = {'num_classes': num_classes, 'is_sparse': False}
-#         self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
-#         pre_output, out = hsigmoid(x, w, label, bias, num_classes)
-#         self.outputs = {'PreOut': pre_output, 'Out': out}
-
-#     def test_check_output(self):
-#         self.check_output()
-
-#     def test_check_grad(self):
-#         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
-
-# class TestHSigmoidOpSparse(OpTest):
-#     def setUp(self):
-#         self.op_type = "hierarchical_sigmoid"
-#         num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
-#         feature_size = 8
-#         batch_size = 4
-#         x = np.random.random((batch_size, feature_size)).astype("float32") * 2
-#         w = np.random.random(
-#             (num_classes - 1, feature_size)).astype("float32") * 2
-#         label = np.array([0, 1, 4, 5])
-#         ptable = np.array(
-#             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
-#              (0, 2, -1, -1,
-#               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-#         pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
-#             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
-#         bias = np.random.random((1, num_classes - 1)).astype("float32")
-#         self.attrs = {'num_classes': num_classes, 'is_sparse': True}
-#         self.inputs = {
-#             'X': x,
-#             'W': w,
-#             'PTable': ptable,
-#             'PCode': pcode,
-#             'Label': label,
-#             'Bias': bias
-#         }
-#         pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
-#                                                  bias, num_classes)
-#         self.outputs = {'PreOut': pre_output, 'Out': out}
-
-#     def test_check_output(self):
-#         print("checking output in CostumTree")
-#         self.check_output()
-
-
-class TestHSigmoidOpWithSparseGrad():
-    def hs_net_conf(self):
-        emb = fluid.layers.data(name="x", shape=[3], dtype='int64')
+class TestHSigmoidOp(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+        w = np.random.random(
+            (num_classes - 1, feature_size)).astype("float32") * 2
+        label = np.random.randint(0, num_classes, (batch_size, 1))
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
+        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
+        pre_output, out = hsigmoid(x, w, label, bias, num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
+
+
+class TestHSigmoidOpSparse(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32")
+        w = np.random.random((num_classes - 1, feature_size)).astype("float32")
+        label = np.array([0, 1, 4, 5])
+        ptable = np.array(
+            [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
+             (0, 2, -1, -1,
+              -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
+        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+            1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store the codes of 1,2,5,6
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes, 'is_sparse': True}
+        self.inputs = {
+            'X': x,
+            'W': w,
+            'PTable': ptable,
+            'PCode': pcode,
+            'Label': label,
+            'Bias': bias
+        }
+        pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
+                                                 bias, num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+
+    def test_check_output(self):
+        print("checking output in CostumTree")
+        self.check_output()
+
+
+class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
+    def hs_net_conf(self, is_sparse):
+        input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
         ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64')
         pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64')
         label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-        data_list = [emb, ptable, pcode, label]
+
+        data_list = [input_word, ptable, pcode, label]
+
+        emb = fluid.layers.embedding(
+            input=input_word,
+            is_sparse=False,
+            size=[3, 3],
+            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                scale=1 / math.sqrt(3))))
+
         cost = fluid.layers.hsigmoid(
             input=emb,
-            label=predict_word,
-            non_leaf_num=4,
+            label=label,
+            non_leaf_num=3,
            ptable=ptable,
             pcode=pcode,
             is_costum=True,
-            is_sparse=True)
+            is_sparse=is_sparse)
 
         avg_cost = fluid.layers.reduce_mean(cost)
 
         return avg_cost, data_list
 
-    def test_training_test(self):
-        print("im here")
-        w = np.arange(12).reshape(4, 3)
-        x = np.ones((2, 3))
-        ptable = np.array([(1, 2, -1), (1, 2, -1)])
-        pcode = np.array([(1, 0, -1), (0, 0, -1)])
-        label = np.array([(1, 4)])
-
-        loss, data_list = hs_net_conf()
-        optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
-        optimizer.minimize(loss)
-
-        main_program = fluid.default_main_program()
-
-        place = fluid.CPUPlace()
-        feeder = fluid.DataFeeder(feed_list=data_list, place=place)
-        data_name_list = [var.name for var in data_list]
-        exe = fluid.Executor(place)
-        exe.run(fluid.default_startup_program())
-        for pass_id in range(args.num_passes):
+    def training_test(self, is_sparse):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            start_up = fluid.default_startup_program()
+            start_up.random_seed = 1  # Fix random seed
+            x = np.arange(6).reshape(6)
+            ptable = np.array([(1, 2, -1), (1, 2, -1)])
+            pcode = np.array([(1, 0, -1), (0, 0, -1)])
+            label = np.array([1, 4])
+
+            loss, data_list = self.hs_net_conf(is_sparse)
+            optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
+            optimizer.minimize(loss)
+
+            main_program = fluid.default_main_program()
+            # print("main program: {program}".format(program=str(main_program)))
+            place = fluid.CPUPlace()
+            feeder = fluid.DataFeeder(feed_list=data_list, place=place)
+            exe = fluid.Executor(place)
+
+            exe.run(start_up)
+            result = list()
             for i in range(10):
-                data = [w, x[i % 2], ptable[i % 2], pcode[i % 2], label[i % 2]]
+                data = [([[x[i % 2]]], [list(ptable[i % 2])],
+                         [list(pcode[i % 2])], [label[i % 2]])]
+
                 loss_val = exe.run(main_program,
                                    feed=feeder.feed(data),
                                    fetch_list=[loss])
-                print("loss is: {loss}".format(loss=loss))
-
-
-# class TestHSigmoidOpWithCostumTree(OpTest):
-#     def setUp(self):
-#         self.op_type = "hierarchical_sigmoid"
-#         num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
-#         feature_size = 8
-#         batch_size = 4
-#         x = np.random.random((batch_size, feature_size)).astype("float32") * 2
-#         w = np.random.random(
-#             (num_classes - 1, feature_size)).astype("float32") * 2
-#         label = np.array([0, 1, 4, 5])
-#         ptable = np.array(
-#             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
-#              (0, 2, -1, -1,
-#               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-#         pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
-#             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
-#         bias = np.random.random((1, num_classes - 1)).astype("float32")
-#         self.attrs = {'num_classes': num_classes, 'is_sparse': False}
-#         self.inputs = {
-#             'X': x,
-#             'W': w,
-#             'PTable': ptable,
-#             'PCode': pcode,
-#             'Label': label,
-#             'Bias': bias
-#         }
-#         pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
-#                                                  bias, num_classes)
-#         self.outputs = {'PreOut': pre_output, 'Out': out}
-
-#     def test_check_output(self):
-#         print("checking output in CostumTree")
-#         self.check_output()
-
-#     def test_check_grad(self):
-#         print("checking outputGrad in CostumTree")
-#         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
+                result.append(loss_val)
+        return result
+
+    def test_hs_grad_with_sparse(self):
+        dense_result = self.training_test(is_sparse=False)
+        sparse_result = self.training_test(is_sparse=True)
+        assert (dense_result == sparse_result)
+
+
+class TestHSigmoidOpWithCostumTree(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+        w = np.random.random(
+            (num_classes - 1, feature_size)).astype("float32") * 2
+        label = np.array([0, 1, 4, 5])
+        ptable = np.array(
+            [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
+             (0, 2, -1, -1,
+              -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
+        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+            1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store the codes of 1,2,5,6
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
+        self.inputs = {
+            'X': x,
+            'W': w,
+            'PTable': ptable,
+            'PCode': pcode,
+            'Label': label,
+            'Bias': bias
+        }
+        pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
+                                                 bias, num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+
+    def test_check_output(self):
+        print("checking output in CostumTree")
+        self.check_output()
+
+    def test_check_grad(self):
+        print("checking outputGrad in CostumTree")
+        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
 
 
 if __name__ == '__main__':
     unittest.main()
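As background for the custom-tree tests: each row of ptable lists the non-leaf nodes on one sample's root-to-leaf path, each row of pcode gives the left/right bit at each of those nodes, and -1 pads paths shorter than the table width. A hedged sketch of how a per-sample loss could be accumulated from such a pair (conventions vary; the reference implementation is hsigmoidWithCustomTree above, and the sign convention here is only one common choice):

import numpy as np

def path_loss(x_i, w, path, code):
    # x_i: one input row; w: [num_nodes, feature];
    # path/code: one row of ptable/pcode, -1-padded.
    loss = 0.0
    for node, bit in zip(path, code):
        if node == -1:  # padding: the path has ended
            break
        # one binary classifier per visited inner node
        prob = 1.0 / (1.0 + np.exp(-np.dot(w[node], x_i)))
        loss += -np.log(prob if bit == 1 else 1.0 - prob)
    return loss

x_i = np.ones(3)
w = np.arange(12.0).reshape(4, 3)
print(path_loss(x_i, w, path=(1, 2, -1), code=(1, 0, -1)))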
