16
16
from op_test import OpTest
17
17
18
18
19
- def bipartite_match (distance , match_indices , match_dis ):
19
+ def bipartite_match (distance , match_indices , match_dist ):
20
20
"""Bipartite Matching algorithm.
21
21
Arg:
22
22
distance (numpy.array) : The distance of two entries with shape [M, N].
23
23
match_indices (numpy.array): the matched indices from column to row
24
24
with shape [1, N], it must be initialized to -1.
25
- match_dis (numpy.array): The matched distance from column to row
25
+ match_dist (numpy.array): The matched distance from column to row
26
26
with shape [1, N], it must be initialized to 0.
27
27
"""
28
28
match_pair = []
@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
36
36
row_indices = - 1 * np .ones ((row , ), dtype = np .int )
37
37
38
38
idx = 0
39
- for i , j , dis in match_sorted :
39
+ for i , j , dist in match_sorted :
40
40
if idx >= row :
41
41
break
42
- if match_indices [j ] == - 1 and row_indices [i ] == - 1 and dis > 0 :
42
+ if match_indices [j ] == - 1 and row_indices [i ] == - 1 and dist > 0 :
43
43
match_indices [j ] = i
44
44
row_indices [i ] = j
45
- match_dis [j ] = dis
45
+ match_dist [j ] = dist
46
46
idx += 1
47
47
48
48
@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
55
55
n = len (lod ) - 1
56
56
m = distance .shape [1 ]
57
57
match_indices = - 1 * np .ones ((n , m ), dtype = np .int )
58
- match_dis = np .zeros ((n , m ), dtype = np .float32 )
58
+ match_dist = np .zeros ((n , m ), dtype = np .float32 )
59
59
for i in range (len (lod ) - 1 ):
60
60
bipartite_match (distance [lod [i ]:lod [i + 1 ], :], match_indices [i , :],
61
- match_dis [i , :])
62
- return match_indices , match_dis
61
+ match_dist [i , :])
62
+ return match_indices , match_dist
63
63
64
64
65
65
class TestBipartiteMatchOpForWithLoD (OpTest ):
66
66
def setUp (self ):
67
67
self .op_type = 'bipartite_match'
68
68
lod = [[0 , 5 , 11 , 23 ]]
69
- dis = np .random .random ((23 , 217 )).astype ('float32' )
70
- match_indices , match_dis = batch_bipartite_match (dis , lod [0 ])
69
+ dist = np .random .random ((23 , 217 )).astype ('float32' )
70
+ match_indices , match_dist = batch_bipartite_match (dist , lod [0 ])
71
71
72
- self .inputs = {'DistMat' : (dis , lod )}
72
+ self .inputs = {'DistMat' : (dist , lod )}
73
73
self .outputs = {
74
74
'ColToRowMatchIndices' : (match_indices ),
75
- 'ColToRowMatchDis' : (match_dis ),
75
+ 'ColToRowMatchDis' : (match_dist ),
76
76
}
77
77
78
78
def test_check_output (self ):
@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
83
83
def setUp (self ):
84
84
self .op_type = 'bipartite_match'
85
85
lod = [[0 , 8 ]]
86
- dis = np .random .random ((8 , 17 )).astype ('float32' )
87
- match_indices , match_dis = batch_bipartite_match (dis , lod [0 ])
86
+ dist = np .random .random ((8 , 17 )).astype ('float32' )
87
+ match_indices , match_dist = batch_bipartite_match (dist , lod [0 ])
88
88
89
- self .inputs = {'DistMat' : dis }
89
+ self .inputs = {'DistMat' : dist }
90
90
self .outputs = {
91
- 'ColToRowMatchIndices' : ( match_indices ) ,
92
- 'ColToRowMatchDis' : ( match_dis ) ,
91
+ 'ColToRowMatchIndices' : match_indices ,
92
+ 'ColToRowMatchDis' : match_dist ,
93
93
}
94
94
95
95
def test_check_output (self ):
0 commit comments