Commit 3f79295

[Fix] acc_mutual_info metric calculation bug (#3035)

* fix: bug in acc_mutual_info slicing; add `target_delimiter` to uncond choices
* add tests
1 parent 82a9936 commit 3f79295
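
Why the `target_delimiter` matters here: the conditional requests score the continuation `f"{target_delimiter}{choice}"`, so the unconditional baseline must score the same string for log(P(choice|ctx)) − log(P(choice)) to compare like with like. A minimal sketch of the before/after request pairs (the context string and config values below are illustrative, not lm-eval internals):

    # hypothetical values for illustration only
    target_delimiter, choice = " ", "B"
    conditional = ("Question: ...", f"{target_delimiter}{choice}")  # scores " B" given ctx
    buggy_unconditional = ("", f"{choice}")                         # scored "B" (no delimiter)
    fixed_unconditional = ("", f"{target_delimiter}{choice}")       # now scores " B" as well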

2 files changed: 163 additions & 3 deletions

lm_eval/api/task.py

Lines changed: 7 additions & 3 deletions
@@ -1481,7 +1481,10 @@ def construct_requests(
             # here mutual info refers to calculating
             # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
             # in other words normalizing by subtracting the unconditional logprob of each choice.
-            aux_arguments = [("", f"{choice}") for choice in choices]
+            # TODO: should these be strided? will have to modify the processing in process_results if so
+            aux_arguments = [
+                ("", f"{target_delimiter}{choice}") for choice in choices
+            ]

             arguments.extend(aux_arguments)

@@ -1580,11 +1583,12 @@ def process_results(self, doc, results):
             ):
                 # then we are doing mutual info.
                 # this stores the "dryrun" / unconditional answer loglikelihoods
-                lls_unconditional = lls[1::2]
+                # as we extend the args list with unconditional ("", continuation) pairs
+                lls_unconditional = lls[len(choices) :]
                 if len(lls_unconditional) != len(choices):
                     raise ValueError
                 # and this stores our "regular" conditional loglikelihoods
-                lls = lls[::2]
+                lls = lls[: len(choices)]

                 pred = np.argmax(lls)
                 pred_norm = np.argmax(lls / completion_len)
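
A minimal standalone sketch of what the slicing fix changes, assuming three choices and the appended (not interleaved) request layout produced by `arguments.extend(aux_arguments)` above:

    choices = ["A", "B", "C"]
    lls = [-2.0, -1.0, -3.0,   # conditional log P(choice | ctx), in choice order
           -2.5, -2.0, -2.5]   # unconditional log P(choice), appended afterwards

    # old (buggy) slicing assumed the two kinds alternated per choice:
    assert lls[::2] == [-2.0, -3.0, -2.0]   # mixes conditional and unconditional
    # fixed slicing matches the block layout:
    assert lls[: len(choices)] == [-2.0, -1.0, -3.0]
    assert lls[len(choices) :] == [-2.5, -2.0, -2.5]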

tests/test_metrics.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
from lm_eval.api.task import ConfigurableTask, TaskConfig


class MockConfigurableTask(ConfigurableTask):
    """Mock task for testing metrics"""

    def __init__(self):
        # Create a minimal config
        config = {
            "task": "test_acc_mutual_info",
            "output_type": "multiple_choice",
            "metric_list": [{"metric": "acc"}, {"metric": "acc_mutual_info"}],
            "doc_to_choice": ["A", "B", "C"],
            "doc_to_target": 1,  # Correct answer is index 1 (choice "B")
            "target_delimiter": " ",
        }

        # Initialize with minimal setup
        self._config = TaskConfig(**config)
        self.OUTPUT_TYPE = "multiple_choice"

        # Set up required attributes
        self.multiple_input = 0
        self.multiple_target = 0

        # Set up metrics
        self._metric_fn_list = {"acc": None, "acc_mutual_info": None}
        self._metric_fn_kwargs = {"acc": {}, "acc_mutual_info": {}}
        self._aggregation_list = {}
        self._higher_is_better = {}

    def doc_to_choice(self, doc):
        return ["A", "B", "C"]

    def doc_to_target(self, doc):
        return 1  # Choice "B" is correct

    # Required abstract methods (minimal implementations)
    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def download(self, **kwargs):
        pass


def test_acc_mutual_info_slicing():
    """Test that acc_mutual_info correctly slices conditional and unconditional loglikelihoods"""

    task = MockConfigurableTask()

    # Simulate loglikelihood results for 3 choices
    # Format: [(loglikelihood, is_greedy), ...]
    # First 3 are conditional P(choice|context), next 3 are unconditional P(choice)

    # Combined results as they would come from the model
    # Order: conditional_1, conditional_2, conditional_3, unconditional_1, unconditional_2, unconditional_3
    # Conditional: [-2.0, -1.0, -3.0] - Choice B (index 1) has highest prob
    # Unconditional: [-2.5, -2.0, -2.5] - Choice B has higher unconditional prob too
    results = [
        (-2.0, False),
        (-1.0, True),
        (-3.0, False),  # Conditional
        (-2.5, False),
        (-2.0, False),
        (-2.5, False),
    ]  # Unconditional

    # Test the process_results method
    doc = {}  # Mock document
    result_dict = task.process_results(doc, results)

    # Verify that both acc and acc_mutual_info are calculated
    assert "acc" in result_dict
    assert "acc_mutual_info" in result_dict

    # Both should be 1.0 since choice B (index 1) is correct and has highest probability
    assert result_dict["acc"] == 1.0, f"Expected acc=1.0, got {result_dict['acc']}"
    assert result_dict["acc_mutual_info"] == 1.0, (
        f"Expected acc_mutual_info=1.0, got {result_dict['acc_mutual_info']}"
    )


def test_acc_mutual_info_different_predictions():
    """Test case where conditional and mutual info predictions differ"""

    task = MockConfigurableTask()

    # Mutual info calculation:
    # Conditional: A=-1.0, B=-2.0, C=-3.0 (A wins conditionally)
    # Unconditional: A=-0.5, B=-2.0, C=-3.0 (A has much higher unconditional prob)
    # Mutual info = conditional - unconditional:
    # A: -1.0 - (-0.5) = -0.5
    # B: -2.0 - (-2.0) = 0.0  <- B wins with mutual info!
    # C: -3.0 - (-3.0) = 0.0

    results = [
        (-1.0, True),
        (-2.0, False),
        (-3.0, False),  # Conditional (A wins)
        (-0.5, False),
        (-2.0, False),
        (-3.0, False),
    ]  # Unconditional

    doc = {}
    result_dict = task.process_results(doc, results)

    # Regular acc should be 0.0 (A predicted, but B is correct)
    assert result_dict["acc"] == 0.0, f"Expected acc=0.0, got {result_dict['acc']}"

    # Mutual info should be 1.0 (B predicted with mutual info, and B is correct)
    assert result_dict["acc_mutual_info"] == 1.0, (
        f"Expected acc_mutual_info=1.0, got {result_dict['acc_mutual_info']}"
    )


def test_acc_mutual_info_without_metric():
    """Test that normal behavior works when acc_mutual_info is not in metric list"""

    # Create task without acc_mutual_info
    config = {
        "task": "test_normal",
        "output_type": "multiple_choice",
        "metric_list": [{"metric": "acc"}],  # Only acc, no acc_mutual_info
        "doc_to_choice": ["A", "B", "C"],
        "doc_to_target": 1,
        "target_delimiter": " ",
    }

    task = MockConfigurableTask()
    task._config = TaskConfig(**config)
    task._metric_fn_list = {"acc": None}  # Only acc

    # Only conditional loglikelihoods (no unconditional since acc_mutual_info not requested)
    results = [(-2.0, False), (-1.0, True), (-3.0, False)]  # 3 choices, B wins

    doc = {}
    result_dict = task.process_results(doc, results)

    # Should only have acc, not acc_mutual_info
    assert "acc" in result_dict
    assert "acc_mutual_info" not in result_dict
    assert result_dict["acc"] == 1.0


if __name__ == "__main__":
    test_acc_mutual_info_slicing()
    test_acc_mutual_info_different_predictions()
    test_acc_mutual_info_without_metric()
    print("All tests passed!")
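
For intuition, the mutual-info re-ranking exercised by `test_acc_mutual_info_different_predictions` can be reproduced in a few lines of NumPy, using the same numbers as the test's comments:

    import numpy as np

    cond = np.array([-1.0, -2.0, -3.0])    # log P(choice | ctx): A wins on raw acc
    uncond = np.array([-0.5, -2.0, -3.0])  # log P(choice): A is just a likelier string
    mutual_info = cond - uncond            # log(P(choice|ctx) / P(choice)) = [-0.5, 0.0, 0.0]
    assert np.argmax(cond) == 0            # plain acc predicts A (index 0)
    assert np.argmax(mutual_info) == 1     # acc_mutual_info predicts B (index 1, the target)

Note that B and C tie at 0.0 here; `np.argmax` returns the first maximum, which is why the test's comment says B wins.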
