|
26 | 26 |
|
27 | 27 | import logging
|
28 | 28 | import os
|
29 |
| -from typing import Callable, Literal |
| 29 | +from typing import Callable, Literal, Union |
30 | 30 |
|
31 | 31 | import nltk
|
32 | 32 | import numpy as np
|
@@ -1055,3 +1055,118 @@ def compute_score(self, pred: str, gold: str) -> int:
|
1055 | 1055 | if self.type_exact_match == "suffix":
|
1056 | 1056 | return 1 if pred.endswith(gold) else 0
|
1057 | 1057 | return 1 if gold == pred else 0
|
| 1058 | + |
| 1059 | + |
class PassAtK:
    """Pass@k metric for one item evaluated over several sampled predictions.

    Every prediction is scored against the single gold (exact match by default) and the
    per-sample scores are aggregated with the unbiased pass@k estimator described in
    https://arxiv.org/pdf/2107.03374.
    """

    def __init__(
        self,
        k: int,
        n: Union[int, None] = None,
        normalize_gold: Union[Callable, None] = None,
        normalize_pred: Union[Callable, None] = None,
        strip_strings: bool = False,
        sample_scoring_function: Union[Callable[[str, str], float], str, None] = None,
    ):
        """Computing pass at k

        Args:
            k (int): Threshold for the number of successful attempts.
            n (int, optional): Number of samples to generate. When left to None, the number of
                predictions received by each `compute` call is used instead.
            normalize_gold (callable, optional): Function to use to normalize the reference strings.
                Defaults to None if no normalization is applied.
            normalize_pred (callable, optional): Function to use to normalize the predicted strings.
                Defaults to None if no normalization is applied.
            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
            sample_scoring_function (callable or str, optional): Function to use to score each sample.
                Either pass the full function (should take a string prediction and a string gold, and
                return a score between 0 and 1), a string (any of `prefix`, `suffix` or `full`) to define
                the type of exact match that you want, or nothing to default to "full".
                `prefix` checks if the prediction starts with the gold,
                `suffix` if the prediction ends with the gold,
                `full` if the prediction and gold are equal

        Raises:
            ValueError: If `sample_scoring_function` is a string other than `prefix`, `suffix` or `full`.
        """
        self.k = k
        self.n = n
        self.normalize_gold = normalize_gold
        self.normalize_pred = normalize_pred
        self.strip_strings = strip_strings

        # Select how each individual sample is scored.
        if callable(sample_scoring_function):
            self.score_sample = sample_scoring_function
            self.type_exact_match = None
        else:
            if isinstance(sample_scoring_function, str):
                if sample_scoring_function not in ["prefix", "suffix", "full"]:
                    raise ValueError(
                        "sample_scoring_function must be a callable or one of prefix, suffix, or full. "
                        f"Was {sample_scoring_function} instead."
                    )
                self.type_exact_match = sample_scoring_function
            else:
                # Anything that is neither a callable nor a string falls back to full exact match.
                self.type_exact_match = "full"
            self.score_sample = self.default_sample_scoring

    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
        """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
        It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
        then aggregates the scores over the samples using a pass@k.

        Args:
            golds (list[str]): Reference targets (exactly one is expected)
            predictions (list[str]): Sampled predictions; only the first `n` are scored.

        Returns:
            float: Aggregated score over the current sample's items.

        Raises:
            Exception: If more than one gold is provided.
            ValueError: If no gold is provided.
        """
        if len(golds) > 1:
            raise Exception("Cannot compute pass@k with several golds")
        if not golds:
            raise ValueError("Cannot compute pass@k without a gold")

        # Resolve n per call instead of storing it on the instance: items may come with
        # different numbers of predictions, and the first one must not fix n for the others.
        n = self.n
        if n is None:
            n = len(predictions)
            # Warn only once per metric instance to avoid flooding the logs.
            if not getattr(self, "_warned_n_undefined", False):
                logger.warning(
                    "n undefined in the pass@k. We assume it's the same as the sample's number of predictions."
                )
                self._warned_n_undefined = True
        elif len(predictions) < n:
            # Missing samples are still accounted for as failures in the estimator below.
            logger.warning(f"Number of predictions is less than {n} for pass@k.")

        gold = self.get_processed_gold(golds[0])
        all_scores = [self.score_sample(self.get_processed_pred(pred=pred), gold) for pred in predictions[:n]]

        return self.pass_at_k(all_scores, n)

    def get_processed_gold(self, gold: str) -> str:
        """Strips (if requested) then normalizes (if a normalizer was given) the gold."""
        if self.strip_strings:
            gold = gold.strip()

        if self.normalize_gold:
            gold = self.normalize_gold(gold)

        return gold

    def get_processed_pred(self, pred: str) -> str:
        """Strips (if requested) then normalizes (if a normalizer was given) the prediction.

        Empty or None predictions are returned as an empty string without further processing.
        """
        if not pred:
            return ""

        if self.strip_strings:
            pred = pred.strip()

        if self.normalize_pred:
            pred = self.normalize_pred(pred)

        return pred

    def default_sample_scoring(self, pred: str, gold: str) -> int:
        """Returns 1 if the prediction matches the gold under `type_exact_match`, 0 otherwise."""
        if self.type_exact_match == "prefix":
            return 1 if pred.startswith(gold) else 0
        if self.type_exact_match == "suffix":
            return 1 if pred.endswith(gold) else 0
        return 1 if gold == pred else 0

    def pass_at_k(self, all_scores: list[int], n: Union[int, None] = None) -> float:
        """Algo from https://arxiv.org/pdf/2107.03374

        Args:
            all_scores (list[int]): Per-sample scores; a sample counts as correct when its score equals 1.
            n (int, optional): Total number of samples to account for. Defaults to `self.n`,
                or to `len(all_scores)` when `self.n` is not set either.

        Returns:
            float: Estimate of the probability that at least one of k samples is correct.
        """
        if n is None:
            n = self.n if self.n is not None else len(all_scores)

        c: int = all_scores.count(1)
        # Without any correct sample the item cannot pass, even when n < k
        # (the shortcut below would otherwise report a perfect score).
        if c == 0:
            return 0.0
        # Fewer than k failures: any draw of k samples contains a correct one.
        if n - c < self.k:
            return 1.0

        return float(1.0 - np.prod(1.0 - self.k / np.arange(n - c + 1, n + 1)))
0 commit comments