Skip to content

Commit 60e85da

Browse files
authored
add Mbpp instruct (#2995)
* feat: add mbpp_instruct * fix: update generation_kwargs to use an empty until list * fix: correct predictions formatting in pass_at_1 function * fix: improve code block extraction by checking first without opening backticks * fix mbpp `pass_at_1`
1 parent d57e3d6 commit 60e85da

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
task: mbpp_instruct
2+
dataset_path: google-research-datasets/mbpp
3+
dataset_name: full
4+
unsafe_code: true
5+
output_type: generate_until
6+
test_split: test
7+
doc_to_text: "You are an expert Python programmer, and here is your task:\n{{text}}\nYour code should pass these tests:\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}"
8+
doc_to_target: "{% if is_fewshot is defined %}{{code}}\n```{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}"
9+
gen_prefix: "\n```python\n"
10+
target_delimiter: ""
11+
metric_list:
12+
- metric: !function utils.pass_at_1
13+
aggregation: mean
14+
higher_is_better: true
15+
filter_list:
16+
- name: "extract_code"
17+
filter:
18+
- function: "custom"
19+
filter_fn: !function utils.build_predictions
20+
generation_kwargs:
21+
max_gen_toks: 256
22+
until: []
23+
do_sample: false
24+
num_fewshot: 3
25+
fewshot_config:
26+
sampler: first_n
27+
samples: !function utils.list_fewshot_samples
28+
metadata:
29+
version: 1.0
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
include: mbpp_instruct.yaml
2+
task: mbpp_plus_instruct
3+
dataset_path: evalplus/mbppplus
4+
dataset_name: null
5+
doc_to_text: "{{prompt if prompt is defined else text}} Your code should satisfy the following assertion:\n{{test_list[0]}}"
6+
doc_to_target: "{{test_list[0]}}"
7+
gen_prefix: "Here is a solution to this programming problem:\n```python\n"
8+
num_fewshot: 0
9+
generation_kwargs:
10+
max_gen_toks: 1024
11+
until: []
12+
do_sample: false

lm_eval/tasks/mbpp/utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import re
2+
from typing import Union
3+
14
import evaluate as hf_evaluate
25

36

@@ -12,14 +15,41 @@
1215
raise e
1316

1417

15-
def pass_at_1(references, predictions):
18+
def pass_at_1(
    references: Union[str, list[str]],
    predictions: Union[str, list[str], list[list[str]]],
) -> float:
    """Return the pass@1 score for generated code evaluated against test code.

    Both arguments are normalised to the nested shape handed to
    ``pass_at_k.compute``: one reference string per problem and one list of
    candidate strings per problem.

    Args:
        references: Test code for a single problem (a bare string) or one
            test-code string per problem.
        predictions: A single candidate (bare string), one candidate per
            problem (flat list of strings), or a list of candidate lists,
            one list per problem.

    Returns:
        The value stored under ``"pass@1"`` in the first element of the
        result returned by ``pass_at_k.compute``.

    Note:
        ``pass_at_k`` is a module-level metric object created earlier in this
        file (presumably the Hugging Face ``code_eval`` metric, which executes
        the candidate code -- confirm against the loader above).
    """
    if isinstance(references, str):
        references = [references]
    if isinstance(predictions, str):
        # A bare string is one candidate for one problem. Without this guard
        # the check below would see predictions[0] (a single character) as a
        # str and split the program into one "candidate" per character.
        predictions = [[predictions]]
    elif isinstance(predictions[0], str):
        # Flat list: one candidate per problem -> wrap each in its own list.
        predictions = [[candidate] for candidate in predictions]
    # No debug printing here: this runs once per document, and dumping every
    # reference/prediction pair to stdout floods evaluation logs.
    return pass_at_k.compute(
        references=references,
        predictions=predictions,
        k=[1],
    )[0]["pass@1"]
2132

2233

34+
def extract_code_blocks(text: str) -> str:
    """Extract the first code block from a model completion.

    The task's ``gen_prefix`` already ends with an opening ```` ```python ````
    fence, so a completion is normally the *continuation* of an open code
    block: the code is everything before the first closing fence. If the
    model instead opens a fence of its own, the first complete fenced block
    in the completion is used.

    Args:
        text: Raw completion text produced after the generation prefix.

    Returns:
        The code inside the block with at most one leading and one trailing
        newline removed, or ``""`` when no closing fence is present.
    """
    fence = "```"

    stripped = text.lstrip()
    if stripped.startswith(fence):
        # The model emitted its own opening fence. The optional language tag
        # is consumed only when it is followed by a newline, so the first
        # word of the code itself can never be mistaken for the tag.
        own_block = re.search(
            r"```(?:\w*[ \t]*\n)?(.*?)\n?```", stripped, re.DOTALL
        )
        return own_block.group(1) if own_block else ""

    # Continuation of the fence opened by gen_prefix: take everything up to
    # the first closing fence. Prepending a bare "```" and re-matching with an
    # optional ``\w+`` language tag would swallow the first identifier
    # ("def", "import", ...) of the generated code.
    code, closing, _ = text.partition(fence)
    if not closing:
        return ""
    if code.startswith("\n"):
        code = code[1:]
    if code.endswith("\n"):
        code = code[:-1]
    return code
47+
48+
49+
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Turn raw model responses into extracted code, preserving nesting.

    Args:
        resps: One list of raw response strings per document.
        docs: The corresponding documents; accepted for the filter signature
            but not read here.

    Returns:
        A list parallel to ``resps`` in which every response string has been
        replaced by the result of ``extract_code_blocks``.
    """
    extracted: list[list[str]] = []
    for doc_responses in resps:
        extracted.append(list(map(extract_code_blocks, doc_responses)))
    return extracted
51+
52+
2353
def list_fewshot_samples():
2454
return [
2555
{

0 commit comments

Comments
 (0)