Skip to content

Commit 45c8a88

Browse files
authored
add crf_decoding layer (#6274)
* add crf_decoding layer * fix some typo * fix test_crf_decoding_op
1 parent e760641 commit 45c8a88

File tree

8 files changed

+61
-25
lines changed

8 files changed

+61
-25
lines changed

paddle/operators/crf_decoding_op.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,18 @@ class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
3636
"w. See more details in comments of the linear_chain_crf operator.");
3737
AddInput(
3838
"Label",
39-
"(LoDTensor, LoDTensor<int>). The ground truth with shape "
39+
"(LoDTensor, LoDTensor<int64_t>). The ground truth with shape "
4040
"[N x 1]. This input is optional. See more details in the operator's "
4141
"comments.")
4242
.AsDispensable();
43-
AddOutput("ViterbiPath",
44-
"(LoDTensor, LoDTensor<int>). The decoding results. What to "
45-
"return changes depending on whether the Input(Label) (the groud "
46-
"truth) is given. See more details in the operator's comment.");
43+
AddOutput(
44+
"ViterbiPath",
45+
"(LoDTensor, LoDTensor<int64_t>). The decoding results. What to "
46+
"return changes depending on whether the Input(Label) (the ground "
47+
"truth) is given. See more details in the operator's comment.");
4748
AddComment(R"DOC(
4849
The crf_decoding operator reads the emission feature weights and the transition
49-
freature weights learned by the linear_chain_crf operator. It implements the
50+
feature weights learned by the linear_chain_crf operator. It implements the
5051
Viterbi algorithm which is a dynamic programming algorithm for finding the most
5152
likely sequence of hidden states, called the Viterbi path, that results in a
5253
sequence of observed tags.
@@ -60,14 +61,14 @@ operator.
6061
6162
When Input(Label) is given, the crf_decoding operator returns a row vector
6263
with shape [N x 1] whose values are fixed to be 0, indicating an incorrect
63-
prediction, or 1 indicating a tag is correctly predicted. Such an ouput is the
64+
prediction, or 1 indicating a tag is correctly predicted. Such an output is the
6465
input to chunk_eval operator.
6566
6667
2. Input(Label) is not given:
6768
6869
This is the standard decoding process.
6970
70-
The crf_decoding operator returns a row vecotr with shape [N x 1] whose values
71+
The crf_decoding operator returns a row vector with shape [N x 1] whose values
7172
range from 0 to maximum tag number - 1. Each element indicates an index of a
7273
predicted tag.
7374
)DOC");

paddle/operators/crf_decoding_op.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
4343
const size_t level = 0;
4444
const size_t seq_num = lod[level].size() - 1;
4545

46-
int* path = decoded_path->mutable_data<int>(platform::CPUPlace());
47-
math::SetConstant<platform::CPUPlace, int>()(ctx.device_context(),
48-
decoded_path, 0);
46+
int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace());
47+
math::SetConstant<platform::CPUPlace, int64_t>()(ctx.device_context(),
48+
decoded_path, 0);
4949
for (size_t i = 0; i < seq_num; ++i) {
5050
int start_pos = static_cast<int>(lod[level][i]);
5151
int end_pos = static_cast<int>(lod[level][i + 1]);
@@ -57,7 +57,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
5757
if (label) {
5858
PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
5959
"The Input(Label) should be a sequence.");
60-
const int* label_value = label->data<int>();
60+
const int64_t* label_value = label->data<int64_t>();
6161
size_t batch_size = emission_weights->dims()[0];
6262
for (size_t i = 0; i < batch_size; ++i) {
6363
path[i] = label_value[i] == path[i] ? 1 : 0;
@@ -76,7 +76,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
7676

7777
const T* x = emission_weights.data<T>();
7878
const T* w = transition_weights.data<T>();
79-
int* path = decoded_path->data<int>();
79+
int64_t* path = decoded_path->data<int64_t>();
8080

8181
// alpha is a memo table. An element alpha(k, v) records the score of the
8282
// best sequence of tags from position 1 to position k with v being the end

python/paddle/v2/fluid/framework.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def __init__(self,
237237

238238
def find_name(var_list, name):
239239
for var_name in var_list:
240-
if var_name == name:
240+
if var_list[var_name] is not None and var_name == name:
241241
return True
242242
return False
243243

python/paddle/v2/fluid/layer_helper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import itertools
33

4-
from framework import Variable, default_main_program, default_startup_program, \
4+
from framework import Variable, Parameter, default_main_program, default_startup_program, \
55
unique_name, dtype_is_floating
66
from paddle.v2.fluid.initializer import Constant, Xavier
77
from param_attr import ParamAttr
@@ -122,6 +122,12 @@ def create_parameter(self,
122122
return self.main_program.global_block().create_parameter(
123123
dtype=dtype, shape=shape, **attr.to_kwargs())
124124

125+
def get_parameter(self, name):
126+
param = self.main_program.global_block().var(name)
127+
if not isinstance(param, Parameter):
128+
raise ValueError("no Parameter name %s found" % name)
129+
return param
130+
125131
def create_tmp_variable(self, dtype):
126132
return self.main_program.current_block().create_var(
127133
name=unique_name(".".join([self.name, 'tmp'])),

python/paddle/v2/fluid/layers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,24 @@ def linear_chain_crf(input,
477477
return log_likelihood
478478

479479

480+
def crf_decoding(input,
481+
param_attr,
482+
label=None,
483+
main_program=None,
484+
startup_program=None):
485+
helper = LayerHelper('crf_decoding', **locals())
486+
transition = helper.get_parameter(param_attr.name)
487+
viterbi_path = helper.create_tmp_variable(dtype=helper.input_dtype())
488+
helper.append_op(
489+
type='crf_decoding',
490+
inputs={"Emission": [input],
491+
"Transition": transition,
492+
"Label": label},
493+
outputs={"ViterbiPath": [viterbi_path]})
494+
495+
return viterbi_path
496+
497+
480498
def assign(input, output, main_program=None, startup_program=None):
481499
helper = LayerHelper('assign', **locals())
482500
helper.append_op(

python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,19 @@ def main():
137137
param_attr=fluid.ParamAttr(
138138
name='crfw', learning_rate=mix_hidden_lr))
139139
avg_cost = fluid.layers.mean(x=crf_cost)
140+
140141
# TODO(qiao)
141-
# 1. add crf_decode_layer and evaluator
142-
# 2. use other optimizer and check why out will be NAN
142+
# check other optimizers and check why out will be NAN
143143
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001)
144144
sgd_optimizer.minimize(avg_cost)
145145

146+
# TODO(qiao)
147+
# add dependency track and move this config before optimizer
148+
crf_decode = fluid.layers.crf_decoding(
149+
input=feature_out,
150+
label=target,
151+
param_attr=fluid.ParamAttr(name='crfw'))
152+
146153
train_data = paddle.batch(
147154
paddle.reader.shuffle(
148155
paddle.dataset.conll05.test(), buf_size=8192),
@@ -168,7 +175,6 @@ def main():
168175
feed=feeder.feed(data),
169176
fetch_list=[avg_cost])
170177
avg_cost_val = np.array(outs[0])
171-
172178
if batch_id % 10 == 0:
173179
print("avg_cost=" + str(avg_cost_val))
174180

python/paddle/v2/fluid/tests/test_crf_decoding_op.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ def __init__(self, emission_weights, transition_weights,
2020
self.w = transition_weights[2:, :]
2121

2222
self.track = np.zeros(
23-
(seq_start_positions[-1], self.tag_num), dtype="int32")
23+
(seq_start_positions[-1], self.tag_num), dtype="int64")
2424
self.decoded_path = np.zeros(
25-
(seq_start_positions[-1], 1), dtype="int32")
25+
(seq_start_positions[-1], 1), dtype="int64")
2626

2727
def _decode_one_sequence(self, decoded_path, x):
2828
seq_len, tag_num = x.shape
2929
alpha = np.zeros((seq_len, tag_num), dtype="float64")
30-
track = np.zeros((seq_len, tag_num), dtype="int32")
30+
track = np.zeros((seq_len, tag_num), dtype="int64")
3131

3232
for i in range(tag_num):
3333
alpha[0, i] = self.a[i] + x[0, i]
@@ -125,10 +125,10 @@ def setUp(self):
125125
axis=0)
126126

127127
labels = np.random.randint(
128-
low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
128+
low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int64")
129129
predicted_labels = np.ones(
130-
(lod[-1][-1], 1), dtype="int32") * (TAG_NUM - 1)
131-
expected_output = (labels == predicted_labels).astype("int32")
130+
(lod[-1][-1], 1), dtype="int64") * (TAG_NUM - 1)
131+
expected_output = (labels == predicted_labels).astype("int64")
132132

133133
self.inputs = {
134134
"Emission": (emission, lod),

python/paddle/v2/fluid/tests/test_layers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import paddle.v2.fluid.layers as layers
55
import paddle.v2.fluid.nets as nets
66
from paddle.v2.fluid.framework import Program, program_guard
7+
from paddle.v2.fluid.param_attr import ParamAttr
78

89

910
class TestBook(unittest.TestCase):
@@ -132,8 +133,12 @@ def test_linear_chain_crf(self):
132133
images = layers.data(name='pixel', shape=[784], dtype='float32')
133134
label = layers.data(name='label', shape=[1], dtype='int32')
134135
hidden = layers.fc(input=images, size=128)
135-
crf = layers.linear_chain_crf(input=hidden, label=label)
136+
crf = layers.linear_chain_crf(
137+
input=hidden, label=label, param_attr=ParamAttr(name="crfw"))
138+
crf_decode = layers.crf_decoding(
139+
input=hidden, param_attr=ParamAttr(name="crfw"))
136140
self.assertNotEqual(crf, None)
141+
self.assertNotEqual(crf_decode, None)
137142

138143
print(str(program))
139144

0 commit comments

Comments
 (0)