
Commit 0563daa

[feat] add graphwalks (#3377)
1 parent 5b144bb commit 0563daa


6 files changed: +257 -0 lines changed


lm_eval/tasks/README.md

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ provided to the individual README.md files for each subfolder.
  | [global_piqa](global_piqa/README.md) | Multilingual (non-parallel) commonsense reasoning benchmark covering 116 language varieties with culturally-specific examples from 65 countries | Multiple (116 languages) **Human authored** |
  | [glue](glue/README.md) | General Language Understanding Evaluation benchmark to test broad language abilities. | English |
  | [gpqa](gpqa/README.md) | Tasks designed for general public question answering and knowledge verification. | English |
+ | [graphwalks](graphwalks/README.md) | A multi-hop reasoning long-context benchmark | English |
  | [gsm8k](gsm8k/README.md) | A benchmark of grade school math problems aimed at evaluating reasoning capabilities. | English |
  | [groundcocoa](groundcocoa/README.md) | A benchmark evaluating the conditional and compositional reasoning of language models using a grounding task. | English |
  | [haerae](haerae/README.md) | Tasks focused on assessing detailed factual and historical knowledge. | Korean |

lm_eval/tasks/graphwalks/README.md

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
# GraphWalks: a multi-hop reasoning long-context benchmark

In GraphWalks, the model is given a graph represented by its edge list inside a long prompt and is asked to perform an operation on it (for example, a breadth-first search from a given node or finding a node's parents), returning the resulting list of nodes.

### Dataset

HuggingFace: https://huggingface.co/datasets/openai/graphwalks
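For a quick look at what the task consumes, here is a minimal sketch; the repository, file, and field names are taken from the configs in this commit, and everything else is illustrative:

```python
import datasets

# Same call that utils.load_dataset below makes for graphwalks_128k
ds = datasets.load_dataset(
    "openai/graphwalks",
    data_files="graphwalks_128k_and_shorter.parquet",
    split="train",
)

row = ds[0]
print(row["prompt"][:200])   # long edge-list prompt handed to the model
print(row["answer_nodes"])   # gold list of node IDs
```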
### Groups and Tasks

#### Groups

* `graphwalks`: Run both `graphwalks_128k` and `graphwalks_1M`

#### Tasks

* `graphwalks_128k`: Prompts up to 128k tokens of context
* `graphwalks_1M`: Prompts between 256k and 1M tokens of context
> [!NOTE]
> `max_gen_toks` is set to `16384`, but non-reasoning models typically do not need this many tokens.
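A rough sketch of running the tasks through the harness's Python entry point; the model name is a placeholder rather than a recommendation, and these prompts need a genuinely long context window:

```python
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=Qwen/Qwen2.5-7B-Instruct",  # placeholder long-context model
    tasks=["graphwalks_128k"],  # or ["graphwalks"] to run both subtasks
    batch_size=1,
)
print(results["results"]["graphwalks_128k"])
```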
### Checklist

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?

If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
group: graphwalks
task:
  - graphwalks_128k
  - graphwalks_1M
aggregate_metric_list:
  - metric: f1
    weight_by_size: true
  - metric: flexible_f1
    weight_by_size: true
metadata:
  version: 0.0
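For intuition, `weight_by_size: true` should make the group score a size-weighted mean of the per-task scores rather than a simple average (my reading of the harness's aggregation); a toy illustration with made-up numbers:

```python
def size_weighted_mean(scores, sizes):
    """Weighted mean of per-task scores, weighted by document count."""
    return sum(s * n for s, n in zip(scores, sizes)) / sum(sizes)

# Made-up numbers: f1 of 0.60 on the 128k split, 0.40 on the 1M split
print(size_weighted_mean([0.60, 0.40], [1000, 200]))  # ~0.567
```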
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
task: graphwalks_128k
custom_dataset: !function utils.load_dataset
dataset_kwargs:
  data_file: graphwalks_128k_and_shorter.parquet
output_type: generate_until
test_split: train
doc_to_text: "{{prompt}}"
doc_to_target: "{{answer_nodes}}"
process_results: !function utils.process_results
target_delimiter: ""
generation_kwargs:
  until:
    - "</s>"
    - "<|im_end|>"
    - "<|endoftext|>"
  max_gen_toks: 16384
metric_list:
  - metric: f1
    aggregation: mean
    higher_is_better: true
  - metric: flexible_f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
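Read together, this config renders and scores each document roughly as follows; this is a paraphrase, not the harness's actual code, and `model_generate` is a hypothetical stand-in for the model call:

```python
from lm_eval.tasks.graphwalks import utils  # the utils.py added in this commit


def score_document(doc, model_generate):
    # doc_to_text / doc_to_target pull fields straight from the dataset row
    prompt = doc["prompt"]
    # Generation stops at any configured stop string or after 16384 new tokens
    output = model_generate(
        prompt,
        stop=["</s>", "<|im_end|>", "<|endoftext|>"],
        max_new_tokens=16384,
    )
    # Scoring is delegated to utils.process_results rather than exact match
    return utils.process_results(doc, [output])
```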
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
task: graphwalks_1M
custom_dataset: !function utils.load_dataset
dataset_kwargs:
  data_file: graphwalks_256k_to_1mil.parquet
output_type: generate_until
test_split: train
doc_to_text: "{{prompt}}"
doc_to_target: "{{answer_nodes}}"
process_results: !function utils.process_results
target_delimiter: ""
generation_kwargs:
  until:
    - "</s>"
    - "<|im_end|>"
    - "<|endoftext|>"
  max_gen_toks: 16384
metric_list:
  - metric: f1
    aggregation: mean
    higher_is_better: true
  - metric: flexible_f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0

lm_eval/tasks/graphwalks/utils.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
1+
import re
2+
from typing import List, Tuple
3+
4+
import datasets
5+
6+
7+
def load_dataset(**kwargs):
8+
"""
9+
Load the graphwalks dataset with specific data file.
10+
11+
Args:
12+
kwargs: Must contain 'data_file' key specifying which parquet file to load
13+
14+
Returns:
15+
Dictionary with 'train' split containing the dataset
16+
"""
17+
data_file = kwargs.get("data_file")
18+
if not data_file:
19+
raise ValueError("data_file must be specified in dataset_kwargs")
20+
21+
dataset = datasets.load_dataset(
22+
"openai/graphwalks", data_files=data_file, split="train"
23+
)
24+
return {"train": dataset}
25+
26+
27+
def extract_answer_list(response: str) -> Tuple[List[str], bool]:
28+
"""
29+
Extract the answer list from a model response.
30+
31+
Args:
32+
response: The model's generated response
33+
34+
Returns:
35+
Tuple of (list of nodes, is_error)
36+
- list of nodes: extracted node IDs
37+
- is_error: True if parsing failed, False otherwise
38+
"""
39+
# Get the very last line of the response (strip trailing newlines first)
40+
line = response.rstrip("\n").split("\n")[-1]
41+
42+
# Check if formatted correctly
43+
if "Final Answer:" not in line:
44+
return [], True
45+
46+
# Extract the list part using regex with capturing group
47+
match = re.search(r"Final Answer:\s*\[(.*)\]", line)
48+
if match:
49+
# Extract content between brackets using group(1)
50+
bracket_content = match.group(1)
51+
# Handle empty list case
52+
if not bracket_content.strip():
53+
return [], False
54+
# Split by comma and clean up whitespace and quotes
55+
result_list = [
56+
item.strip().strip("'\"")
57+
for item in bracket_content.split(",")
58+
if item.strip()
59+
]
60+
return result_list, False
61+
else:
62+
return [], True
63+
64+
65+
def extract_answer_list_flexible(response: str) -> Tuple[List[str], bool]:
66+
"""
67+
Extract the answer list from a model response (flexible version).
68+
Searches backwards through all lines to find "Final Answer:" pattern.
69+
More lenient than extract_answer_list which only checks the last line.
70+
71+
Args:
72+
response: The model's generated response
73+
74+
Returns:
75+
Tuple of (list of nodes, is_error)
76+
- list of nodes: extracted node IDs
77+
- is_error: True if parsing failed, False otherwise
78+
"""
79+
lines = response.rstrip("\n").split("\n")
80+
for line in reversed(lines):
81+
match = re.search(r"Final Answer:\s*\[(.*)\]", line)
82+
if match:
83+
# Extract content between brackets using group(1)
84+
bracket_content = match.group(1)
85+
# Handle empty list case
86+
if not bracket_content.strip():
87+
return [], False
88+
# Split by comma and clean up whitespace and quotes
89+
result_list = [
90+
item.strip().strip("'\"")
91+
for item in bracket_content.split(",")
92+
if item.strip()
93+
]
94+
return result_list, False
95+
96+
# No "Final Answer:" found anywhere
97+
return [], True
98+
99+
100+
def process_results(doc, results):
101+
"""
102+
Process results and compute set-based F1 scores.
103+
Returns both strict F1 (last line only) and flexible F1 (search all lines).
104+
105+
Args:
106+
doc: Document containing ground truth answer_nodes
107+
results: List containing model generation
108+
109+
Returns:
110+
Dictionary with f1 and flexible_f1 scores
111+
"""
112+
# Extract model response (first element of results)
113+
response = results[0]
114+
115+
# Get ground truth nodes
116+
gold_nodes = doc["answer_nodes"]
117+
118+
# Parse the response using strict extraction
119+
predicted_nodes_strict, _ = extract_answer_list(response)
120+
sampled_set_strict = set(predicted_nodes_strict)
121+
truth_set = set(gold_nodes)
122+
123+
# Calculate strict F1
124+
n_overlap_strict = len(sampled_set_strict & truth_set)
125+
n_sampled_strict = len(sampled_set_strict)
126+
n_golden = len(truth_set)
127+
128+
recall_strict = n_overlap_strict / n_golden if n_golden > 0 else 0.0
129+
precision_strict = (
130+
n_overlap_strict / n_sampled_strict if n_sampled_strict > 0 else 0.0
131+
)
132+
f1_strict = (
133+
2 * (recall_strict * precision_strict) / (recall_strict + precision_strict)
134+
if (recall_strict + precision_strict) > 0
135+
else 0.0
136+
)
137+
138+
# Parse the response using flexible extraction
139+
predicted_nodes_flexible, _ = extract_answer_list_flexible(response)
140+
sampled_set_flexible = set(predicted_nodes_flexible)
141+
142+
# Calculate flexible F1
143+
n_overlap_flexible = len(sampled_set_flexible & truth_set)
144+
n_sampled_flexible = len(sampled_set_flexible)
145+
146+
recall_flexible = n_overlap_flexible / n_golden if n_golden > 0 else 0.0
147+
precision_flexible = (
148+
n_overlap_flexible / n_sampled_flexible if n_sampled_flexible > 0 else 0.0
149+
)
150+
f1_flexible = (
151+
2
152+
* (recall_flexible * precision_flexible)
153+
/ (recall_flexible + precision_flexible)
154+
if (recall_flexible + precision_flexible) > 0
155+
else 0.0
156+
)
157+
158+
return {
159+
"f1": f1_strict,
160+
"flexible_f1": f1_flexible,
161+
}
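A small self-contained check of the parsing and scoring above; the response text and node names are fabricated, and the import path assumes the task directory is importable as a module:

```python
from lm_eval.tasks.graphwalks.utils import extract_answer_list, process_results

response = "Let me trace the graph step by step...\nFinal Answer: [node_a, node_b, node_d]"
doc = {"answer_nodes": ["node_a", "node_b", "node_c"]}

print(extract_answer_list(response))     # (['node_a', 'node_b', 'node_d'], False)
print(process_results(doc, [response]))  # f1 and flexible_f1 both ~0.667 here
```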
