 
 class SampleLevelComputation(ABC):
     @abstractmethod
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         raise NotImplementedError
 
     def __str__(self):
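For reference, a minimal sketch of a concrete subclass under the new (doc, model_response) argument order; the subclass name and scoring rule are illustrative, not taken from this diff, and only Doc.get_golds() and ModelResponse.final_text (both visible elsewhere in the diff) are assumed.

class ContainsGold(SampleLevelComputation):
    # Hypothetical scorer: 1.0 if any gold reference appears verbatim in the first prediction.
    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
        prediction = model_response.final_text[0]
        return float(any(gold in prediction for gold in doc.get_golds()))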
@@ -444,7 +444,7 @@ def __init__(self, length_normalization: bool = False):
         """
         self.length_normalization = length_normalization
 
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
         """Mean reciprocal rank. Measures the quality of a ranking of choices (ordered by correctness).
 
         Args:
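The docstring above describes MRR as the quality of a ranking of choices; as a rough standalone illustration (not the code in this diff), the reciprocal rank of the gold choice could be computed like this, assuming per-choice scores are available:

import numpy as np

def reciprocal_rank(choice_scores: list[float], gold_ix: int) -> float:
    # Rank choices best-first, then take 1 / (1-based rank of the gold choice).
    ranking = np.argsort(choice_scores)[::-1]
    rank = int(np.where(ranking == gold_ix)[0][0]) + 1
    return 1.0 / rank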
@@ -1129,14 +1129,13 @@ def __init__(
                 raise ValueError(f"Unknown normalization function: {normalize}")
         else:
             self.normalize = normalize
-
         self.strip_strings = strip_strings
 
         if callable(sample_scoring_function):
             self.compute_score = sample_scoring_function
             self.type_exact_match = None
         elif isinstance(sample_scoring_function, SampleLevelComputation):
-            self.score_sample = sample_scoring_function.compute
+            self.compute_score = sample_scoring_function.compute
             self.type_exact_match = None
         else:
             if isinstance(sample_scoring_function, str):
@@ -1145,11 +1144,9 @@ def __init__(
                         f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
                     )
                 self.type_exact_match = sample_scoring_function
-                self.score_sample = self.default_sample_scoring
             else:
                 self.type_exact_match = "full"
                 self.compute_score = self.default_sample_scoring
-                self.score_sample = self.default_sample_scoring
 
     def preprocess(self, text: str) -> str:
         if not text:
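Taken together, the two hunks above mean the callable and SampleLevelComputation forms of sample_scoring_function are now stored on self.compute_score, and self.score_sample disappears. A hedged example of the plain-callable form; the class owning this __init__ is not named in these hunks, so MyExactMatchMetric below is a placeholder:

def casefold_match(doc: Doc, model_response: ModelResponse) -> float:
    # Illustrative scorer: case-insensitive comparison of the first prediction with the first gold.
    return float(model_response.final_text[0].casefold() == doc.get_golds()[0].casefold())

metric = MyExactMatchMetric(sample_scoring_function=casefold_match)  # placeholder class name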
@@ -1176,19 +1173,19 @@ def name_metrics(self) -> str | list[str]:
         raise NotImplementedError
 
 
-class AvgAtK(SamplingMetric, SampleLevelComputation):
-    def __init__(self, k: int | None = None, **kwargs):
-        """Sample score averages all the individual k predictions scores.
+class AvgAtN(SamplingMetric, SampleLevelComputation):
+    def __init__(self, n: int | None = None, **kwargs):
+        """Sample score averages all the individual n predictions scores.
 
         Args:
-            k (int | None): The number of top choices to consider.
+            n (int | None): Number of samples to generate
             **kwargs: Additional keyword arguments.
         """
         super().__init__(**kwargs)
-        self.k = k
-        self.attribute_must_be_set = ["k"]
+        self.n = n
+        self.attribute_must_be_set = ["n"]
 
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         """Computes the metric over a list of golds and predictions for one single sample.
         It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
         then compares it to the gold.
@@ -1203,36 +1200,32 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
         """
         all_scores = []
         for i in range(self.k):
-            all_scores.append(self.score_sample(doc, model_response[i]))
+            all_scores.append(self.compute_score(doc, model_response[i]))
 
         avg_score = np.mean(all_scores)
         return avg_score
 
     def num_samples(self):
-        """Get the number of samples for this metric.
-
-        Returns:
-            int: The number of samples
-        """
-        return self.k
+        return self.n
 
 
-class MajAtK(SamplingMetric, SampleLevelComputation):
-    def __init__(self, k: int | None = None, **kwargs):
+class MajAtN(SamplingMetric, SampleLevelComputation):
+    def __init__(self, n: int | None = None, **kwargs):
         """An exact match class.
 
         Args:
-            k (int): The number of top choices to consider.
+            n (int): Total number of samples to generate
             **kwargs: Additional keyword arguments.
         """
         super().__init__(**kwargs)
 
-        self.k = k
-        self.attribute_must_be_set = ["k"]
+        self.n = n
+        self.attribute_must_be_set = ["n"]
 
     def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         """Computes the metric over a list of golds and predictions for one single sample.
-        It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones, then compares it to the gold.
+        It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
+        then compares it to the gold.
 
         Args:
             doc (Doc): The document containing gold references.
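AvgAtN above simply averages the per-prediction scores over the n samples; a toy illustration with invented values:

import numpy as np

# If three of four sampled predictions score 1.0 and one scores 0.0, avg@4 is 0.75.
per_sample_scores = [1.0, 1.0, 0.0, 1.0]
print(np.mean(per_sample_scores))  # 0.75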
@@ -1243,39 +1236,38 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
             float: Aggregated score over the current sample's items.
         """
         if self.k is None:
-            raise Exception("You did not set the value of k")
+            raise Exception("You did not set the value of n")
 
         golds = doc.get_golds()
-
         if len(golds) > 1:
-            raise Exception("Cannot compute maj@k with several golds")
+            raise Exception("Cannot compute maj@n with several golds")
 
         processed_choices = [self.preprocess(text=g) for g in doc.get_golds()]
         new_doc = Doc(
             choices=processed_choices,
             query=doc.query,
-            gold_index=list(range(len(processed_choices))),
+            gold_index=doc.gold_index,
         )
         all_answers = []
-        for pred in model_response.final_text[: self.k]:
+        for pred in model_response.final_text[: self.n]:
             all_answers.append(self.preprocess(text=pred))
         majority_prediction = max(all_answers, key=all_answers.count)
         new_model_response = ModelResponse(
             text=[majority_prediction],
         )
-        return self.compute_score(new_doc, new_model_response)
+        return self.compute_score(new_model_response, new_doc)
 
     def num_samples(self):
-        return self.k
+        return self.n
 
 
 class PassAtK(SamplingMetric, SampleLevelComputation):
     def __init__(self, k: int | None = None, n: int | None = None, **kwargs):
-        """Computing pass at k
+        """Computing pass at k with an estimator
 
         Args:
-            k (int | None): Threshold for the number of successful attempts.
-            n (int | None): Number of samples to generate.
+            k (int | None): Number of correct samples threshold
+            n (int | None): Total number of samples to generate.
             **kwargs: Additional keyword arguments.
         """
         super().__init__(**kwargs)
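MajAtN reduces the n sampled predictions to a single majority answer before scoring it against the gold; a standalone sketch of that reduction with invented predictions (the real path goes through preprocess and compute_score as shown above):

predictions = ["berlin", "munich", "berlin", "berlin", "munich"]
majority_prediction = max(predictions, key=predictions.count)  # "berlin" wins 3 to 2
print(majority_prediction)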
@@ -1320,7 +1312,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
             new_model_response = ModelResponse(
                 text=[cur_pred],
             )
-            all_scores.append(self.score_sample(doc=new_doc, model_response=new_model_response))
+            all_scores.append(self.compute_score(doc=new_doc, model_response=new_model_response))
 
         return self.pass_at_k(all_scores)
 
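The per-prediction scores collected above are handed to self.pass_at_k, which is not part of this diff; pass@k is commonly the unbiased estimator 1 - C(n-c, k) / C(n, k) over n samples with c successes, so a sketch under that assumption looks like:

from math import comb

def pass_at_k_estimate(all_scores: list[int], k: int) -> float:
    # Assumed form of the estimator; the in-repo pass_at_k helper may differ.
    n = len(all_scores)   # total samples generated
    c = sum(all_scores)   # samples scored as correct
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)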
@@ -1348,8 +1340,8 @@ def __init__(
         """Computing G-Pass@k from http://arxiv.org/abs/2412.13147
 
         Args:
-            k (Union[int, list[int]] | None): The number of successful attempts to be considered.
-            n (int | None): Number of samples to generate.
+            k (Union[int, list[int]] | None): Number of correct samples threshold
+            n (int | None): Total number of samples to generate.
             thresholds (list[float]): Thresholds to control successful attempts in k generate.
             name_prefix (str | None): Prefix for the metric name.
             **kwargs: Additional keyword arguments.
@@ -1370,7 +1362,7 @@ def k(self):
     def k(self, new_val):
         self._k = as_list(new_val)
 
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
         It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
         then aggregates the scores over the samples using a pass@k.
@@ -1410,7 +1402,7 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
             new_model_response = ModelResponse(
                 text=[cur_pred],
             )
-            all_scores.append(self.score_sample(new_doc, new_model_response))
+            all_scores.append(self.compute_score(new_doc, new_model_response))
 
         return self.g_pass_at_k(all_scores)
 
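The g_pass_at_k helper is likewise outside this diff. Following the cited paper (arXiv:2412.13147), G-Pass@k_tau asks that at least ceil(tau * k) of the k drawn samples be correct, which gives a hypergeometric tail over n samples with c successes; a sketch under that reading:

from math import ceil, comb

def g_pass_at_k_estimate(all_scores: list[int], k: int, tau: float) -> float:
    # Assumed definition of G-Pass@k_tau; the in-repo g_pass_at_k may differ in details.
    n = len(all_scores)   # total samples generated
    c = sum(all_scores)   # samples scored as correct
    return sum(
        comb(c, j) * comb(n - c, k - j) for j in range(ceil(tau * k), min(c, k) + 1)
    ) / comb(n, k)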