@@ -14,6 +14,9 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
#include < set>
17
+ #include < string>
18
+ #include < vector>
19
+
17
20
#include " paddle/fluid/framework/eigen.h"
18
21
#include " paddle/fluid/framework/op_registry.h"
19
22
@@ -36,11 +39,11 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
36
39
};
37
40
38
41
void GetSegments (const int64_t * label, int length,
39
- std::vector<Segment>& segments, int num_chunk_types,
42
+ std::vector<Segment>* segments, int num_chunk_types,
40
43
int num_tag_types, int other_chunk_type, int tag_begin,
41
44
int tag_inside, int tag_end, int tag_single) const {
42
- segments. clear ();
43
- segments. reserve (length);
45
+ segments-> clear ();
46
+ segments-> reserve (length);
44
47
int chunk_start = 0 ;
45
48
bool in_chunk = false ;
46
49
int tag = -1 ;
@@ -58,7 +61,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
58
61
i - 1 , // end
59
62
prev_type,
60
63
};
61
- segments. push_back (segment);
64
+ segments-> push_back (segment);
62
65
in_chunk = false ;
63
66
}
64
67
if (ChunkBegin (prev_tag, prev_type, tag, type, other_chunk_type,
@@ -73,7 +76,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
73
76
length - 1 , // end
74
77
type,
75
78
};
76
- segments. push_back (segment);
79
+ segments-> push_back (segment);
77
80
}
78
81
}
79
82
@@ -177,8 +180,8 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
177
180
for (int i = 0 ; i < num_sequences; ++i) {
178
181
int seq_length = lod[0 ][i + 1 ] - lod[0 ][i];
179
182
EvalOneSeq (inference_data + lod[0 ][i], label_data + lod[0 ][i], seq_length,
180
- output_segments, label_segments, * num_infer_chunks_data,
181
- * num_label_chunks_data, * num_correct_chunks_data,
183
+ & output_segments, & label_segments, num_infer_chunks_data,
184
+ num_label_chunks_data, num_correct_chunks_data,
182
185
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
183
186
tag_inside, tag_end, tag_single, excluded_chunk_types);
184
187
}
@@ -197,10 +200,10 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
197
200
}
198
201
199
202
void EvalOneSeq (const int64_t * output, const int64_t * label, int length,
200
- std::vector<Segment>& output_segments,
201
- std::vector<Segment>& label_segments,
202
- int64_t & num_output_segments, int64_t & num_label_segments,
203
- int64_t & num_correct, int num_chunk_types, int num_tag_types,
203
+ std::vector<Segment>* output_segments,
204
+ std::vector<Segment>* label_segments,
205
+ int64_t * num_output_segments, int64_t * num_label_segments,
206
+ int64_t * num_correct, int num_chunk_types, int num_tag_types,
204
207
int other_chunk_type, int tag_begin, int tag_inside,
205
208
int tag_end, int tag_single,
206
209
const std::set<int >& excluded_chunk_types) const {
@@ -209,25 +212,29 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
209
212
GetSegments (label, length, label_segments, num_chunk_types, num_tag_types,
210
213
other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
211
214
size_t i = 0 , j = 0 ;
212
- while (i < output_segments. size () && j < label_segments. size ()) {
213
- if (output_segments[i] == label_segments[j] &&
214
- excluded_chunk_types.count (output_segments[i] .type ) != 1 ) {
215
- ++num_correct;
215
+ while (i < output_segments-> size () && j < label_segments-> size ()) {
216
+ if (output_segments-> at (i) == label_segments-> at (j) &&
217
+ excluded_chunk_types.count (output_segments-> at (i) .type ) != 1 ) {
218
+ ++(* num_correct) ;
216
219
}
217
- if (output_segments[i] .end < label_segments[j] .end ) {
220
+ if (output_segments-> at (i) .end < label_segments-> at (j) .end ) {
218
221
++i;
219
- } else if (output_segments[i] .end > label_segments[j] .end ) {
222
+ } else if (output_segments-> at (i) .end > label_segments-> at (j) .end ) {
220
223
++j;
221
224
} else {
222
225
++i;
223
226
++j;
224
227
}
225
228
}
226
- for (auto & segment : label_segments) {
227
- if (excluded_chunk_types.count (segment.type ) != 1 ) ++num_label_segments;
229
+ for (auto & segment : (*label_segments)) {
230
+ if (excluded_chunk_types.count (segment.type ) != 1 ) {
231
+ ++(*num_label_segments);
232
+ }
228
233
}
229
- for (auto & segment : output_segments) {
230
- if (excluded_chunk_types.count (segment.type ) != 1 ) ++num_output_segments;
234
+ for (auto & segment : (*output_segments)) {
235
+ if (excluded_chunk_types.count (segment.type ) != 1 ) {
236
+ ++(*num_output_segments);
237
+ }
231
238
}
232
239
}
233
240
};
0 commit comments