
Commit c469334

polish python code and comment, test=develop

1 parent 87648f8

File tree: 4 files changed, +88 -73 lines


paddle/fluid/operators/hierarchical_sigmoid_op.h

Lines changed: 8 additions & 8 deletions
@@ -47,11 +47,11 @@ template <typename DeviceContext, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
-    auto w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
+    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
+    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
     auto* path = ctx.Input<framework::LoDTensor>("PTable");
     auto* code = ctx.Input<framework::LoDTensor>("PathCode");
-    auto label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
+    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
     auto* bias = ctx.Input<framework::LoDTensor>("Bias");
     auto* out = ctx.Output<framework::LoDTensor>("Out");
     auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
@@ -114,8 +114,8 @@ template <typename DeviceContext, typename T>
 class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
-    auto w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
+    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
+    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
     auto* path = ctx.Input<framework::LoDTensor>("PTable");
     auto* code = ctx.Input<framework::LoDTensor>("PathCode");
     auto* bias = ctx.Input<framework::LoDTensor>("Bias");
@@ -124,9 +124,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     bool is_sparse = ctx.Attr<bool>("is_sparse");
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     math::SetConstant<DeviceContext, T> zero;
-    auto label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
-    auto pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
-    auto out_grad = detail::Ref(
+    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
+    auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
+    auto& out_grad = detail::Ref(
         ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")));
     framework::LoDTensor pre_out_grad;

python/paddle/fluid/layers/nn.py

Lines changed: 34 additions & 24 deletions
@@ -4589,37 +4589,40 @@ def hsigmoid(input,
              bias_attr=None,
              name=None,
              non_leaf_num=None,
-             ptable=None,
-             pcode=None,
-             is_costum=False,
+             path_table=None,
+             path_code=None,
+             is_custom=False,
              is_sparse=False):
     """
     The hierarchical sigmoid operator is used to accelerate the training
     process of language model. This operator organizes the classes into a
-    complete binary tree, each leaf node represents a class(a word) and each
+    complete binary tree, or you can use is_custom to pass in your own tree
+    and define the hierarchy yourself. Each leaf node represents a class (a word) and each
     internal node acts as a binary classifier. For each word there's a unique
     path from root to it's leaf node, hsigmoid calculate the cost for each
     internal node on the path, and sum them to get a total cost. hsigmoid can
     achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
     represents the size of word dict.

-    Refer to `Hierarchical Probabilistic Neural Network Language Model
+    With the default tree, refer to `Hierarchical Probabilistic Neural Network Language Model
     <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_

+    If you want to use a custom tree, set 'is_custom' to True and do the following first:
+    1. Use your word dict to build a binary tree; each leaf node should be a word in your word dict.
+    2. Build a dict mapping word_id -> the word's leaf-to-root path; we call it path_table.
+    3. Build a dict mapping word_id -> the code along the word's leaf-to-root path; we call it
+       path_code. A code is the label of each binary classification step on the path, where 1
+       indicates true and 0 indicates false.
+    4. Now each word has its path and the code along that path, so you can pass a batch of paths
+       and codes related to the same batch of inputs.
+
     Args:
         input (Variable): The input tensor variable with shape
            :math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
            and :math:`D` is the feature size.
        label (Variable): The tensor variable contains labels of training data.
            It's a tensor with shape is :math:`[N \\times 1]`.
        num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set
-        non_leaf_num: this defines the number of non-leaf nodes in costumed tree
-        ptable: (Variable|None) this variable can store each batch of samples' path to root,
-            it should be in leaf -> root order
-            ptable should have the same shape with pcode, and for each sample i ptable[i] indicates a np.array like
-            structure and each element in this array is indexes in parent nodes' Weight Matrix.
-        pcode: (Variable|None) this variable can store each batch of samples' code,
-            each code consist with every code of parent nodes. it should be in leaf -> root order
        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
            will create ParamAttr as param_attr. If the Initializer of the param_attr
@@ -4631,8 +4634,15 @@ def hsigmoid(input,
            is not set, the bias is initialized zero. Default: None.
        name (str|None): A name for this layer(optional). If set None, the layer
            will be named automatically. Default: None.
-        is_costum: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is
-            set you need to set ptable/pcode/non_leaf_num, otherwise num_classes should be set
+        non_leaf_num: the number of non-leaf nodes in the custom tree
+        path_table: (Variable|None) stores each batch of samples' paths to the root,
+            in leaf -> root order.
+            path_table should have the same shape as path_code; for each sample i, path_table[i] is
+            an np.array-like structure whose elements are indexes into the parent nodes' weight matrix.
+        path_code: (Variable|None) stores each batch of samples' codes, each composed of the codes
+            of the sample's parent nodes, in leaf -> root order.
+        is_custom: (bool|False) whether to use a user-defined binary tree instead of the default
+            complete binary tree; if set, you need to set path_table/path_code/non_leaf_num,
+            otherwise num_classes should be set
        is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient
            of W and input will be sparse.
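Taken together, the four docstring steps above reduce to two integer arrays per word: the indexes of the non-leaf nodes on its leaf-to-root path, and the binary decision taken at each of those nodes. Below is a minimal sketch of encoding a toy tree; the tree shape, the left/right code convention, and the fixed pad length are illustrative assumptions, while the -1 padding itself follows the unit tests in this commit.

import numpy as np

# Toy tree over 4 words with 3 non-leaf nodes (0 = root):
#
#          0
#         / \
#        1   2
#       / \ / \
#     w0 w1 w2 w3
#
# Assumed convention: code 0 = left child, 1 = right child.
# Each row is one word's leaf -> root sequence, padded with -1.
path_table = np.array([(1, 0, -1),   # w0 passes node 1, then the root
                       (1, 0, -1),   # w1
                       (2, 0, -1),   # w2
                       (2, 0, -1)])  # w3
path_code = np.array([(0, 0, -1),    # w0: left at node 1, left at the root
                      (1, 0, -1),    # w1: right at node 1, left at the root
                      (0, 1, -1),    # w2: left at node 2, right at the root
                      (1, 1, -1)])   # w3: right at node 2, right at the root

At feed time you gather the rows for the batch's labels, which yields the per-sample [N, path_length] arrays that the PTable/PathCode inputs receive in the tests below.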
@@ -4653,22 +4663,22 @@ def hsigmoid(input,
     out = helper.create_variable_for_type_inference(dtype)
     pre_out = helper.create_variable_for_type_inference(dtype)
     dim = input.shape[1]
-    if ((num_classes is None) or (num_classes < 2)) and (not is_costum):
+    if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
         raise ValueError(
             "num_classes must not be less than 2 with default tree")

-    if (is_costum) and (pcode is None):
-        raise ValueError("pcode should not be None with costum tree")
-    elif (is_costum) and (ptable is None):
-        raise ValueError("ptable should not be None with costum tree")
-    elif (is_costum) and (non_leaf_num is None):
+    if (is_custom) and (path_code is None):
+        raise ValueError("path_code should not be None with custom tree")
+    elif (is_custom) and (path_table is None):
+        raise ValueError("path_table should not be None with custom tree")
+    elif (is_custom) and (non_leaf_num is None):
         raise ValueError("non_leaf_num should not be None with custom tree")
     else:
         pass

     weights = None

-    if not is_costum:
+    if not is_custom:
         weights = helper.create_parameter(
             attr=helper.param_attr,
             shape=[num_classes - 1, dim],
@@ -4683,12 +4693,12 @@ def hsigmoid(input,
     inputs = {
         "X": input,
         "W": weights,
-        "PTable": ptable,
-        "PathCode": pcode,
+        "PTable": path_table,
+        "PathCode": path_code,
         "Label": label
     }
     if helper.bias_attr:
-        if not is_costum:
+        if not is_custom:
             bias = helper.create_parameter(
                 attr=helper.bias_attr,
                 shape=[num_classes - 1, 1],
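The guard clauses above make misconfigured custom-tree calls fail at graph-construction time. A small hedged sketch of the expected behavior (the layer names and shapes are illustrative):

import paddle.fluid as fluid
from paddle.fluid import layers

with fluid.program_guard(fluid.Program(), fluid.Program()):
    x = layers.data(name='x', shape=[8], dtype='float32')
    y = layers.data(name='y', shape=[1], dtype='int64')
    try:
        # is_custom=True but no path_code/path_table: the first guard fires.
        layers.hsigmoid(input=x, label=y, non_leaf_num=3, is_custom=True)
    except ValueError as e:
        print(e)  # path_code should not be None with custom tree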

python/paddle/fluid/tests/unittests/test_hsigmoid_op.py

Lines changed: 39 additions & 36 deletions
@@ -43,9 +43,9 @@ def cal_bit(self, bit):


 class CodeTableWithCustomTree(object):
-    def __init__(self, ptable, pcode, index):
-        self.ptable_ = ptable
-        self.pcode_ = pcode
+    def __init__(self, path_table, path_code, index):
+        self.ptable_ = path_table
+        self.pcode_ = path_code
         self.index_ = index

     def cal_index(self, bit):
@@ -102,23 +102,24 @@ def hsigmoid(x, w, label, bias, num_classes):
     return pre_output, out


-def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
+def hsigmoidWithCustomTree(x, w, path_table, path_code, label, bias,
+                           num_classes):
     batch_size = x.shape[0]
-    code_length = len(ptable[0])
+    code_length = len(path_table[0])
     code_table = [0 for _ in range(code_length)]
     # init pre_out with shape [N, code_length]
     pre_output = np.zeros((batch_size, code_length))
     pre_sum = np.zeros((batch_size, 1))
     out = np.zeros((batch_size, 1)).astype("float32")
     if isinstance(bias, np.ndarray):
         for i in range(batch_size):
-            code_table = CodeTableWithCustomTree(ptable, pcode, i)
+            code_table = CodeTableWithCustomTree(path_table, path_code, i)
             length = code_table.get_length()
             for j in range(length):
                 idx = code_table.cal_index(j)
                 pre_output[i][j] += bias[idx][0]
     for i in range(batch_size):
-        code_table = CodeTableWithCustomTree(ptable, pcode, i)
+        code_table = CodeTableWithCustomTree(path_table, path_code, i)
         length = code_table.get_length()
         for j in range(length):
             idx = code_table.cal_index(j)
@@ -127,7 +128,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
     pre_output = np.clip(pre_output, -40.0, 40.0)
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
     for i in range(batch_size):
-        code_table = CodeTableWithCustomTree(ptable, pcode, i)
+        code_table = CodeTableWithCustomTree(path_table, path_code, i)
         length = code_table.get_length()
         sum = 0.0
         for j in range(length):
@@ -173,24 +174,24 @@ def setUp(self):
         x = np.random.random((batch_size, feature_size)).astype("float32")
         w = np.random.random((num_classes - 1, feature_size)).astype("float32")
         label = np.array([0, 1, 4, 5])
-        ptable = np.array(
+        path_table = np.array(
             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
              (0, 2, -1, -1,
               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
         bias = np.random.random((num_classes - 1, 1)).astype("float32")
         self.attrs = {'num_classes': num_classes, 'is_sparse': True}
         self.inputs = {
             'X': x,
             'W': w,
-            'PTable': ptable,
-            'PathCode': pcode,
+            'PTable': path_table,
+            'PathCode': path_code,
             'Label': label,
             'Bias': bias
         }
-        pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
-                                                 bias, num_classes)
+        pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
+                                                 label, bias, num_classes)
         self.outputs = {'PreOut': pre_output, 'Out': out}

     def test_check_output(self):
@@ -200,11 +201,13 @@ def test_check_output(self):
 class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
     def hs_net_conf(self, is_sparse):
         input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
-        ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64')
-        pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64')
+        path_table = fluid.layers.data(
+            name='path_table', shape=[3], dtype='int64')
+        path_code = fluid.layers.data(
+            name='path_code', shape=[3], dtype='int64')
         label = fluid.layers.data(name='label', shape=[1], dtype='int64')

-        data_list = [input_word, ptable, pcode, label]
+        data_list = [input_word, path_table, path_code, label]

         emb = fluid.layers.embedding(
             input=input_word,
@@ -218,9 +221,9 @@ def hs_net_conf(self, is_sparse):
             label=label,
             bias_attr=True,
             non_leaf_num=3,
-            ptable=ptable,
-            pcode=pcode,
-            is_costum=True,
+            path_table=path_table,
+            path_code=path_code,
+            is_custom=True,
             is_sparse=is_sparse)

         avg_cost = fluid.layers.reduce_mean(cost)
@@ -232,8 +235,8 @@ def training_test(self, is_sparse):
         start_up = fluid.default_startup_program()
         start_up.random_seed = 1  # Fix random seed
         x = np.arange(6).reshape(6)
-        ptable = np.array([(1, 2, -1), (1, 2, -1)])
-        pcode = np.array([(1, 0, -1), (0, 0, -1)])
+        path_table = np.array([(1, 2, -1), (1, 2, -1)])
+        path_code = np.array([(1, 0, -1), (0, 0, -1)])
         label = np.array([1, 4])

         loss, data_list = self.hs_net_conf(is_sparse)
@@ -248,8 +251,8 @@ def training_test(self, is_sparse):
         exe.run(start_up)
         result = list()
         for i in range(10):
-            data = [([[x[i % 2]]], [list(ptable[i % 2])],
-                     [list(pcode[i % 2])], [label[i % 2]])]
+            data = [([[x[i % 2]]], [list(path_table[i % 2])],
+                     [list(path_code[i % 2])], [label[i % 2]])]

             loss_val = exe.run(main_program,
                                feed=feeder.feed(data),
@@ -273,24 +276,24 @@ def setUp(self):
         w = np.random.random(
             (num_classes - 1, feature_size)).astype("float32") * 2
         label = np.array([0, 1, 4, 5])
-        ptable = np.array(
+        path_table = np.array(
             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
              (0, 2, -1, -1,
               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
         bias = np.random.random((num_classes - 1, 1)).astype("float32")
         self.attrs = {'num_classes': num_classes, 'is_sparse': False}
         self.inputs = {
             'X': x,
             'W': w,
-            'PTable': ptable,
-            'PathCode': pcode,
+            'PTable': path_table,
+            'PathCode': path_code,
             'Label': label,
             'Bias': bias
         }
-        pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
-                                                 bias, num_classes)
+        pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
+                                                 label, bias, num_classes)
         self.outputs = {'PreOut': pre_output, 'Out': out}

     def test_check_output(self):
@@ -310,26 +313,26 @@ def setUp(self):
         w = np.random.random(
             (num_classes - 1, feature_size)).astype("float32") * 2
         label = np.array([0, 1, 4, 5])
-        ptable = np.array(
+        path_table = np.array(
             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
              (0, 2, -1, -1,
               -1)])  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-        pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
+        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
             1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  #np.array to store
         # bias = np.random.random((num_classes - 1, 1)).astype("float32")
         self.attrs = {'num_classes': num_classes, 'is_sparse': False}
         self.inputs = {
             'X': x,
             'W': w,
-            'PTable': ptable,
-            'PathCode': pcode,
+            'PTable': path_table,
+            'PathCode': path_code,
             'Label': label,
         }
         pre_output, out = hsigmoidWithCustomTree(
             x=x,
             w=w,
-            ptable=ptable,
-            pcode=pcode,
+            path_table=path_table,
+            path_code=path_code,
             label=label,
             bias=None,
             num_classes=num_classes)
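The test helper's accessors used throughout this file (`get_length`, `cal_index`, `cal_bit`) are not part of this diff; the sketch below reconstructs them from their call sites, assuming -1 marks the padded tail of each row.

import numpy as np

class CodeTableWithCustomTree(object):
    """Sketch of the test helper, reconstructed from its call sites;
    the real bodies are not shown in this diff."""

    def __init__(self, path_table, path_code, index):
        self.ptable_ = path_table
        self.pcode_ = path_code
        self.index_ = index

    def get_length(self):
        # Number of real path entries before the -1 padding begins.
        length = 0
        for node in self.ptable_[self.index_]:
            if node < 0:
                break
            length += 1
        return length

    def cal_index(self, bit):
        # Index of the bit-th non-leaf node on this sample's path.
        return self.ptable_[self.index_][bit]

    def cal_bit(self, bit):
        # Binary classification label at the bit-th node on the path.
        return self.pcode_[self.index_][bit]

# Sample 0 from the tests: path (0, 2, -1, -1, -1), code (0, 0, -1, -1, -1)
path_table = np.array([(0, 2, -1, -1, -1)])
path_code = np.array([(0, 0, -1, -1, -1)])
table = CodeTableWithCustomTree(path_table, path_code, 0)
print(table.get_length())  # 2
print(table.cal_index(1))  # 2
print(table.cal_bit(1))    # 0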

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 7 additions & 5 deletions
@@ -190,16 +190,18 @@ def test_hsigmoid(self):
         with program_guard(program2):
             x2 = layers.data(name='x2', shape=[4, 8], dtype='float32')
             y2 = layers.data(name='y2', shape=[4], dtype='int64')
-            ptable = layers.data(name='ptable', shape=[4, 6], dtype='int64')
-            pcode = layers.data(name='pcode', shape=[4, 6], dtype='int64')
+            path_table = layers.data(
+                name='path_table', shape=[4, 6], dtype='int64')
+            path_code = layers.data(
+                name='path_code', shape=[4, 6], dtype='int64')
             self.assertIsNotNone(
                 layers.hsigmoid(
                     input=x2,
                     label=y2,
                     non_leaf_num=6,
-                    ptable=ptable,
-                    pcode=pcode,
-                    is_costum=True))
+                    path_table=path_table,
+                    path_code=path_code,
+                    is_custom=True))
             print(str(program2))

     def test_sequence_expand(self):
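Finally, a hedged end-to-end sketch wiring the renamed arguments together, following the network and feed pattern of `hs_net_conf`/`training_test` in the tests above; the vocabulary size, embedding dimension, and sample values are illustrative.

import paddle.fluid as fluid

main_program = fluid.Program()
start_up = fluid.Program()
with fluid.program_guard(main_program, start_up):
    input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
    path_table = fluid.layers.data(
        name='path_table', shape=[3], dtype='int64')
    path_code = fluid.layers.data(name='path_code', shape=[3], dtype='int64')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # Toy vocabulary of 3 words embedded into 8 dimensions.
    emb = fluid.layers.embedding(input=input_word, size=[3, 8])
    cost = fluid.layers.hsigmoid(
        input=emb,
        label=label,
        bias_attr=True,
        non_leaf_num=3,
        path_table=path_table,
        path_code=path_code,
        is_custom=True)
    loss = fluid.layers.reduce_mean(cost)

place = fluid.CPUPlace()
feeder = fluid.DataFeeder(
    feed_list=[input_word, path_table, path_code, label], place=place)
exe = fluid.Executor(place)
exe.run(start_up)

# One sample: word id 1, its leaf->root path/code (padded with -1), label 1.
data = [([[1]], [[1, 2, -1]], [[0, 0, -1]], [1])]
loss_val = exe.run(main_program, feed=feeder.feed(data), fetch_list=[loss])
print(loss_val)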
