-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathutils.py
More file actions
66 lines (59 loc) · 2.07 KB
/
utils.py
File metadata and controls
66 lines (59 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
def calc_hammingDist(B1, B2):
    """Pairwise Hamming distance between two sets of {-1,+1} hash codes.

    B1: codes of shape (m, q), or (q,) for a single code.
    B2: codes of shape (n, q).
    Returns an (m, n) tensor of Hamming distances.
    """
    code_len = B2.shape[1]
    if B1.dim() < 2:
        # Promote a single code vector to a 1-row matrix.
        B1 = B1.unsqueeze(0)
    # For +-1 codes: <b1, b2> = q - 2 * hamming(b1, b2); solve for hamming.
    inner = B1.mm(B2.t())
    return 0.5 * (code_len - inner)
def calc_map_k(qB, rB, query_L, retrieval_L, k=None):
    """Mean Average Precision (mAP) at top-k for binary-hash retrieval.

    Args:
        qB: query hash codes, {-1,+1}^(m x q).
        rB: retrieval (database) hash codes, {-1,+1}^(n x q).
        query_L: query labels, {0,1}^(m x l).
        retrieval_L: database labels, {0,1}^(n x l).
        k: cap on the number of relevant items counted per query;
           defaults to the full database size.

    Returns:
        0-dim float tensor: average of per-query AP over ALL queries.
        Queries with no relevant database item contribute 0 to the sum
        but still count in the denominator (original behavior kept).
    """
    num_query = query_L.shape[0]
    mean_ap = 0  # accumulator; renamed to avoid shadowing builtin `map`
    if k is None:
        k = retrieval_L.shape[0]
    for q_idx in range(num_query):  # renamed from `iter` (builtin shadowing)
        q_L = query_L[q_idx]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0)
        # Relevant = shares at least one label with the query.
        gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        tsum = torch.sum(gnd)
        if tsum == 0:
            continue  # no relevant item; skip (contributes 0)
        hamm = calc_hammingDist(qB[q_idx, :], rB)
        _, ind = torch.sort(hamm)
        ind.squeeze_()
        gnd = gnd[ind]  # relevance flags in rank order
        total = min(k, int(tsum))
        count = torch.arange(1, total + 1, dtype=torch.float32)
        # 1-based ranks of the first `total` relevant items.
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0
        # Match tindex's device; `.to(...)` is a no-op on CPU and, unlike
        # `.cuda()`, targets the correct device when tindex is on a
        # non-default GPU.
        count = count.to(tindex.device)
        mean_ap = mean_ap + torch.mean(count / tindex)
    mean_ap = mean_ap / num_query
    return mean_ap
if __name__ == '__main__':
    # Smoke test: 4 queries against a 6-item database of 4-bit codes.
    qB = torch.Tensor([
        [1, -1, 1, 1],
        [-1, -1, -1, 1],
        [1, 1, -1, 1],
        [1, 1, 1, -1],
    ])
    rB = torch.Tensor([
        [1, -1, 1, -1],
        [-1, -1, 1, -1],
        [-1, -1, 1, -1],
        [1, 1, -1, -1],
        [-1, 1, -1, -1],
        [1, 1, -1, 1],
    ])
    # Multi-label ground truth: a database item is relevant to a query
    # when their label vectors share at least one active class.
    query_L = torch.Tensor([
        [0, 1, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 1],
        [0, 1, 0, 1],
    ])
    retrieval_L = torch.Tensor([
        [1, 0, 0, 1],
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
    ])
    score = calc_map_k(qB, rB, query_L, retrieval_L)
    print(score)