Skip to content

Commit 59bcb58

Browse files
authored
Merge pull request #7759 from qingqing01/bipartite_match_op_fix
Fix bug and unit test in bipartite_match_op.
2 parents 7eb02eb + 5752892 commit 59bcb58

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

paddle/operators/bipartite_match_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ namespace operators {
2121
using Tensor = framework::Tensor;
2222
using LoDTensor = framework::LoDTensor;
2323

24-
constexpr char kEPS = 1e-6;
25-
2624
class BipartiteMatchOp : public framework::OperatorWithKernel {
2725
public:
2826
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
4644
// The match_dist must be initialized to 0 at first.
4745
void BipartiteMatch(const Tensor& dist, int* match_indices,
4846
T* match_dist) const {
47+
constexpr T kEPS = static_cast<T>(1e-6);
4948
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
5049
int64_t row = dist.dims()[0];
5150
int64_t col = dist.dims()[1];

python/paddle/v2/fluid/tests/test_bipartite_match_op.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
from op_test import OpTest
1717

1818

19-
def bipartite_match(distance, match_indices, match_dis):
19+
def bipartite_match(distance, match_indices, match_dist):
2020
"""Bipartite Matching algorithm.
2121
Arg:
2222
distance (numpy.array) : The distance of two entries with shape [M, N].
2323
match_indices (numpy.array): the matched indices from column to row
2424
with shape [1, N], it must be initialized to -1.
25-
match_dis (numpy.array): The matched distance from column to row
25+
match_dist (numpy.array): The matched distance from column to row
2626
with shape [1, N], it must be initialized to 0.
2727
"""
2828
match_pair = []
@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
3636
row_indices = -1 * np.ones((row, ), dtype=np.int)
3737

3838
idx = 0
39-
for i, j, dis in match_sorted:
39+
for i, j, dist in match_sorted:
4040
if idx >= row:
4141
break
42-
if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
42+
if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
4343
match_indices[j] = i
4444
row_indices[i] = j
45-
match_dis[j] = dis
45+
match_dist[j] = dist
4646
idx += 1
4747

4848

@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
5555
n = len(lod) - 1
5656
m = distance.shape[1]
5757
match_indices = -1 * np.ones((n, m), dtype=np.int)
58-
match_dis = np.zeros((n, m), dtype=np.float32)
58+
match_dist = np.zeros((n, m), dtype=np.float32)
5959
for i in range(len(lod) - 1):
6060
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
61-
match_dis[i, :])
62-
return match_indices, match_dis
61+
match_dist[i, :])
62+
return match_indices, match_dist
6363

6464

6565
class TestBipartiteMatchOpForWithLoD(OpTest):
6666
def setUp(self):
6767
self.op_type = 'bipartite_match'
6868
lod = [[0, 5, 11, 23]]
69-
dis = np.random.random((23, 217)).astype('float32')
70-
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
69+
dist = np.random.random((23, 217)).astype('float32')
70+
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
7171

72-
self.inputs = {'DistMat': (dis, lod)}
72+
self.inputs = {'DistMat': (dist, lod)}
7373
self.outputs = {
7474
'ColToRowMatchIndices': (match_indices),
75-
'ColToRowMatchDis': (match_dis),
75+
'ColToRowMatchDis': (match_dist),
7676
}
7777

7878
def test_check_output(self):
@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
8383
def setUp(self):
8484
self.op_type = 'bipartite_match'
8585
lod = [[0, 8]]
86-
dis = np.random.random((8, 17)).astype('float32')
87-
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
86+
dist = np.random.random((8, 17)).astype('float32')
87+
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
8888

89-
self.inputs = {'DistMat': dis}
89+
self.inputs = {'DistMat': dist}
9090
self.outputs = {
91-
'ColToRowMatchIndices': (match_indices),
92-
'ColToRowMatchDis': (match_dis),
91+
'ColToRowMatchIndices': match_indices,
92+
'ColToRowMatchDis': match_dist,
9393
}
9494

9595
def test_check_output(self):

0 commit comments

Comments
 (0)