Skip to content

Commit ccfb774

Browse files
authored
Add MRR metric (#5097)
* add MRR to metrics * add MRR to metrics * add MRR to metrics * add MRR to metrics * add MRR to metrics
1 parent 920ce29 commit ccfb774

File tree

3 files changed

+147
-5
lines changed

3 files changed

+147
-5
lines changed

paddlenlp/metrics/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .perplexity import Perplexity
16-
from .chunk import ChunkEvaluator
1715
from .bleu import BLEU, BLEUForDuReader
18-
from .rouge import RougeL, RougeLForDuReader, RougeN, Rouge1, Rouge2
19-
from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman, MultiLabelsMetric
16+
from .chunk import ChunkEvaluator
2017
from .distinct import Distinct
21-
from .sighan import DetectionF1, CorrectionF1
18+
from .glue import AccuracyAndF1, Mcc, MultiLabelsMetric, PearsonAndSpearman
19+
from .mrr import MRR
20+
from .perplexity import Perplexity
21+
from .rouge import Rouge1, Rouge2, RougeL, RougeLForDuReader, RougeN
22+
from .sighan import CorrectionF1, DetectionF1
2223
from .span import SpanEvaluator

paddlenlp/metrics/mrr.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
from sklearn.metrics import pairwise_distances
17+
18+
__all__ = ["MRR"]
19+
20+
21+
class MRR:
    """
    MRR - Mean Reciprocal Rank, a popular metric for recommender systems
    and other retrieval tasks. The higher the MRR, the better the model
    performs on the retrieval task.

    Args:
        distance: which algorithm to use to get the distance between embeddings,
            for example: "cosine", "euclidean". The value is forwarded as the
            ``metric`` argument of ``pairwise_distances``.

    """

    def __init__(self, distance="cosine"):
        super().__init__()
        self.distance = distance

    def reset_distance(self, distance):
        """
        Change the algorithm used to calculate distances. It must be supported
        by sklearn.metrics.pairwise_distances.
        """
        self.distance = distance

    def compute_matrix_mrr(self, labels, embeddings):
        """
        Compute the mean reciprocal rank of an embedding matrix.

        For every embedding, the remaining embeddings are ranked by increasing
        distance to it. The reciprocal rank of that embedding is 1 / (rank of the
        nearest other embedding carrying the same label), or 0.0 when no other
        embedding shares its label. The result is the mean over all rows.

        Param:
            - labels(np.array): label matrix, shape=[size, ]
            - embeddings(np.array): embedding matrix, shape=[size, emb_dim]

        Return:
            mrr metric for input embedding matrix.

        Raises:
            ValueError: if labels and embeddings differ in size along dim 0.
        """
        # Accept any array-like input; ndarray inputs are passed through unchanged.
        labels = np.asarray(labels)
        embeddings = np.asarray(embeddings)

        matrix_size = labels.shape[0]
        if matrix_size != embeddings.shape[0]:
            raise ValueError("label and embedding matrix must have same size at dim=0 !")

        row_mrr = []  # reciprocal rank of each embedding in the matrix
        for i in range(matrix_size):
            query, label = embeddings[i, :], labels[i]
            dists = pairwise_distances(query.reshape(1, -1), embeddings, metric=self.distance).reshape(-1)
            # Drop the query itself by index rather than by assuming it sorts first:
            # duplicated (or, for cosine, collinear) embeddings tie at distance 0 and
            # could otherwise leave the query in its own candidate list.
            # A stable sort keeps the order of tied candidates deterministic.
            order = np.argsort(dists, kind="stable")
            order = order[order != i]
            hits = np.flatnonzero(labels[order] == label)
            row_mrr.append(1.0 / (1 + hits[0]) if hits.size else 0.0)
        mrr = np.mean(row_mrr)  # use the mean value as the final mrr metric for the matrix.
        return mrr

tests/metrics/test_mrr.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
import unittest
17+
18+
import numpy as np
19+
20+
from paddlenlp.metrics import MRR
21+
from tests.common_test import CommonTest
22+
23+
24+
class TestMRR(CommonTest):
    """Unit tests for the MRR metric."""

    def setUp(self):
        # Sizes used by the randomly generated cases.
        self.label_num = 10
        self.label_shape = (20,)
        self.embedding_shape = (20, 128)
        # Metric under test, created with the default-style cosine distance.
        self.distance = "cosine"
        self.mrr = MRR(distance=self.distance)

    def get_random_case(self):
        """Return random labels, random embeddings, a randomly chosen distance name and all candidate names."""
        random_labels = np.random.randint(0, self.label_num, size=self.label_shape).astype("int64")
        random_embeddings = np.random.uniform(0.1, 1.0, self.embedding_shape).astype("float64")
        all_distance = ["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"]
        return random_labels, random_embeddings, random.choice(all_distance), all_distance

    def get_true_mrr_case(self):
        """Return a small hand-checked case together with its expected MRR."""
        # cosine similarity: 1,2 => 0.991; 1,3=>0.851; 2,3=>0.912
        rows = [
            [1.0, 2.0, 3.0],
            [1.0, 2.0, 4.0],
            [1.0, 100.0, 1000.0],
        ]
        embeddings = np.array(rows)
        labels = np.array([1, 2, 1], dtype="int64")
        # Rows 1 and 3 share a label and find each other at rank 2; row 2 has no match.
        true_mrr = (1.0 / 2 + 0 + 1.0 / 2) / 3
        return labels, embeddings, "cosine", true_mrr

    def test_reset_distance(self):
        chosen = self.get_random_case()[2]
        self.mrr.reset_distance(chosen)
        self.check_output_equal(self.mrr.distance, chosen)

    def test_compute_matrix_mrr(self):
        # Smoke test: the computation must run for randomly chosen distances.
        for _ in range(100):
            labels, embeddings, chosen, _ = self.get_random_case()
            self.mrr.reset_distance(chosen)
            self.mrr.compute_matrix_mrr(labels, embeddings)

    def test_compute_true_mrr(self):
        labels, embeddings, chosen, expected = self.get_true_mrr_case()
        self.mrr.reset_distance(chosen)
        actual = self.mrr.compute_matrix_mrr(labels, embeddings)
        self.check_output_equal(actual, expected)
70+
71+
72+
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)