Skip to content

Commit 1bc8de3

Browse files
committed
1. Add sequence_num as edit distance op's output
2. Fix evaluator using 'reduce_sum' op instead of 'mean' op
1 parent 0b854bd commit 1bc8de3

File tree

7 files changed

+35
-17
lines changed

7 files changed

+35
-17
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ op_library(parallel_do_op DEPS executor)
156156
# Regist multiple Kernel to pybind
157157
if (WITH_GPU)
158158
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col)
159+
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
159160
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
160161
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
161162
conv_transpose_cudnn_op.cu.cc DEPS vol2col)

paddle/operators/edit_distance_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class EditDistanceOp : public framework::OperatorWithKernel {
2525
PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
2626
PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
2727
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
28+
PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"),
29+
"Output(SequenceNum) shouldn't be null.");
2830
auto hyp_dims = ctx->GetInputDim("Hyps");
2931
auto ref_dims = ctx->GetInputDim("Refs");
3032
PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
@@ -34,6 +36,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
3436
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
3537
"equal to 1.");
3638
ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
39+
ctx->SetOutputDim("SequenceNum", {1});
3740
}
3841

3942
protected:
@@ -54,6 +57,7 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
5457
AddInput("Refs",
5558
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
5659
"The indices for reference strings.");
60+
AddOutput("SequenceNum", "The sequence count of current batch");
5761
AddAttr<bool>("normalized",
5862
"(bool, default false) Indicated whether to normalize "
5963
"the edit distance by the length of reference string.")

paddle/operators/edit_distance_op.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include <algorithm>
1616
#include "paddle/framework/op_registry.h"
17+
#include "paddle/operators/math/math_function.h"
1718
#include "paddle/platform/cuda_helper.h"
1819
#include "paddle/platform/gpu_info.h"
1920

@@ -72,6 +73,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
7273

7374
auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
7475
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
76+
auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
77+
sequence_num->mutable_data<int64_t>(ctx.GetPlace());
7578

7679
auto normalized = ctx.Attr<bool>("normalized");
7780
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
@@ -88,7 +91,11 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
8891
"Reference string %d is empty.", i);
8992
}
9093

91-
auto num_strs = hyp_lod.size() - 1;
94+
const size_t num_strs = hyp_lod.size() - 1;
95+
math::SetConstant<platform::CUDADeviceContext, int64_t> set_constant;
96+
set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
97+
sequence_num, static_cast<int64_t>(num_strs));
98+
9299
out_t->Resize({static_cast<int64_t>(num_strs), 1});
93100
out_t->mutable_data<T>(ctx.GetPlace());
94101
auto out = out_t->data<T>();

paddle/operators/edit_distance_op.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License. */
1616
#include <algorithm>
1717
#include "paddle/framework/eigen.h"
1818
#include "paddle/framework/op_registry.h"
19-
2019
namespace paddle {
2120
namespace operators {
2221

@@ -28,6 +27,8 @@ class EditDistanceKernel : public framework::OpKernel<T> {
2827

2928
auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
3029
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
30+
auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
31+
int64_t* seq_num_data = sequence_num->mutable_data<int64_t>(ctx.GetPlace());
3132

3233
auto normalized = ctx.Attr<bool>("normalized");
3334

@@ -41,6 +42,7 @@ class EditDistanceKernel : public framework::OpKernel<T> {
4142
"Reference string %d is empty.", i);
4243
}
4344
auto num_strs = hyp_lod.size() - 1;
45+
*seq_num_data = static_cast<int64_t>(num_strs);
4446

4547
out_t->Resize({static_cast<int64_t>(num_strs), 1});
4648
out_t->mutable_data<float>(ctx.GetPlace());

python/paddle/v2/fluid/evaluator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,22 +219,22 @@ def __init__(self, input, label, k=1, **kwargs):
219219

220220
self.total_error = self.create_state(
221221
dtype='float32', shape=[1], suffix='total')
222-
self.batch_num = self.create_state(
223-
dtype='float32', shape=[1], suffix='total')
224-
error = layers.edit_distance(input=input, label=label)
225-
error = layers.cast(x=error, dtype='float32')
226-
mean_error = layers.mean(x=error)
227-
layers.sums(input=[self.total_error, mean_error], out=self.total_error)
228-
const1 = layers.fill_constant(shape=[1], value=1.0, dtype="float32")
229-
layers.sums(input=[self.batch_num, const1], out=self.batch_num)
230-
self.metrics.append(mean_error)
222+
self.seq_num = self.create_state(
223+
dtype='int64', shape=[1], suffix='total')
224+
error, seq_num = layers.edit_distance(input=input, label=label)
225+
#error = layers.cast(x=error, dtype='float32')
226+
sum_error = layers.reduce_sum(error)
227+
layers.sums(input=[self.total_error, sum_error], out=self.total_error)
228+
layers.sums(input=[self.seq_num, seq_num], out=self.seq_num)
229+
self.metrics.append(sum_error)
231230

232231
def eval(self, executor, eval_program=None):
233232
if eval_program is None:
234233
eval_program = Program()
235234
block = eval_program.current_block()
236235
with program_guard(main_program=eval_program):
237236
total_error = _clone_var_(block, self.total_error)
238-
batch_num = _clone_var_(block, self.batch_num)
239-
out = layers.elementwise_div(x=total_error, y=batch_num)
237+
seq_num = _clone_var_(block, self.seq_num)
238+
seq_num = layers.cast(x=seq_num, dtype='float32')
239+
out = layers.elementwise_div(x=total_error, y=seq_num)
240240
return np.array(executor.run(eval_program, fetch_list=[out])[0])

python/paddle/v2/fluid/layers/nn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,14 +1918,16 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
19181918

19191919
# edit distance op
19201920
edit_distance_out = helper.create_tmp_variable(dtype="int64")
1921+
sequence_num = helper.create_tmp_variable(dtype="int64")
19211922
helper.append_op(
19221923
type="edit_distance",
19231924
inputs={"Hyps": [input],
19241925
"Refs": [label]},
1925-
outputs={"Out": [edit_distance_out]},
1926+
outputs={"Out": [edit_distance_out],
1927+
"SequenceNum": [sequence_num]},
19261928
attrs={"normalized": normalized})
19271929

1928-
return edit_distance_out
1930+
return edit_distance_out, sequence_num
19291931

19301932

19311933
def ctc_greedy_decoder(input, blank, name=None):

python/paddle/v2/fluid/tests/test_edit_distance_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def setUp(self):
6060

6161
num_strs = len(x1_lod) - 1
6262
distance = np.zeros((num_strs, 1)).astype("float32")
63+
sequence_num = np.array(2).astype("int64")
6364
for i in range(0, num_strs):
6465
distance[i] = Levenshtein(
6566
hyp=x1[x1_lod[i]:x1_lod[i + 1]],
@@ -69,7 +70,7 @@ def setUp(self):
6970
distance[i] = distance[i] / len_ref
7071
self.attrs = {'normalized': normalized}
7172
self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
72-
self.outputs = {'Out': distance}
73+
self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
7374

7475
def test_check_output(self):
7576
self.check_output()
@@ -88,6 +89,7 @@ def setUp(self):
8889

8990
num_strs = len(x1_lod) - 1
9091
distance = np.zeros((num_strs, 1)).astype("float32")
92+
sequence_num = np.array(3).astype("int64")
9193
for i in range(0, num_strs):
9294
distance[i] = Levenshtein(
9395
hyp=x1[x1_lod[i]:x1_lod[i + 1]],
@@ -97,7 +99,7 @@ def setUp(self):
9799
distance[i] = distance[i] / len_ref
98100
self.attrs = {'normalized': normalized}
99101
self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
100-
self.outputs = {'Out': distance}
102+
self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
101103

102104
def test_check_output(self):
103105
self.check_output()

0 commit comments

Comments
 (0)