Skip to content

Commit a4aa5de

Browse files
committed
add more tests
Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
1 parent 8d66b76 commit a4aa5de

File tree

5 files changed

+238
-1
lines changed

5 files changed

+238
-1
lines changed

tests/evals/tasks/test_aime.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
3+
from skythought.evals.tasks.aime.aime_handler import AIMETaskHandler
4+
5+
6+
class MockTaskConfig:
    """Minimal stand-in for the task config consumed by AIMETaskHandler."""

    # Template the handler fills in when building the model prompt.
    # NOTE(review): the placeholder is {prompt} while question_key is
    # "question" — presumably AIMETaskHandler maps the question field into a
    # "prompt" slot before formatting; verify against the handler.
    templating_parameters = {
        "template": "Problem: {prompt}\n\nProvide a numerical answer."
    }
    # Dataset keys the handler reads the ground-truth answer / question from.
    answer_key = "answer"
    question_key = "question"
12+
13+
14+
@pytest.mark.parametrize(
    "sample, generation, want",
    [
        pytest.param(
            {
                "question": "Find the sum of the first 10 positive integers.",
                "answer": "55",
            },
            "The sum is 55",
            True,
            id="answer-present-in-response",
        ),
        pytest.param(
            {
                "question": "What is the value of (3^4 - 2^5)?",
                "answer": "49",
            },
            "48",
            False,
            id="wrong-numeric-answer",
        ),
    ],
)
def test_check_correctness(sample, generation, want):
    """check_correctness returns the expected verdict for each AIME sample."""
    handler = AIMETaskHandler(task_config=MockTaskConfig)
    verdict = handler.check_correctness(sample, generation=generation)
    assert verdict == want
38+
39+
40+
@pytest.mark.parametrize(
    "problem, expected",
    [
        (
            {
                "question": "Find the sum of the first 10 positive integers.",
                "answer": "4",
            },
            "Problem: Find the sum of the first 10 positive integers.\n\nProvide a numerical answer.",
        ),
    ],
)
def test_generate_prompt(problem, expected):
    """generate_prompt renders the configured template with the question text."""
    # Fix: dropped a leftover debug print(problem) that polluted test output.
    handler = AIMETaskHandler(task_config=MockTaskConfig)
    assert handler.generate_prompt(problem) == expected

tests/evals/tasks/test_amc.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
3+
from skythought.evals.tasks.amc23.amc23_handler import AMC23TaskHandler
4+
5+
6+
class MockTaskConfig:
    """Minimal stand-in for the task config consumed by AMC23TaskHandler."""

    # Template the handler fills in when building the model prompt.
    templating_parameters = {
        "template": "Return the answer to the following: {question}"
    }
    # Dataset keys the handler reads the answer / question / choices from.
    answer_key = "answer"
    question_key = "question"
    choices_key = "choices"
13+
14+
15+
@pytest.mark.parametrize(
    "problem, response, expected",
    [
        (
            {"question": "2+2", "answer": "4"},
            "5",
            False,
        ),
        (
            {"question": "3* 25 percent", "answer": " 75%"},
            "My reply is $0.75.",  # ignores dollar signs and normalizes percentages
            True,
        ),
    ],
)
def test_check_correctness(problem, response, expected):
    """check_correctness returns the expected verdict for each AMC sample."""
    # Fix: removed a leftover debug print(...) that invoked check_correctness
    # a second time per case, doubling the work and polluting test output.
    handler = AMC23TaskHandler(task_config=MockTaskConfig)
    assert handler.check_correctness(problem, generation=response) == expected
34+
35+
36+
@pytest.mark.parametrize(
    "sample, rendered",
    [
        pytest.param(
            {"question": "What is the result of 2+2?", "answer": "4"},
            "Return the answer to the following: What is the result of 2+2?",
            id="simple-arithmetic-question",
        ),
    ],
)
def test_generate_prompt(sample, rendered):
    """generate_prompt substitutes the question into the configured template."""
    prompt = AMC23TaskHandler(task_config=MockTaskConfig).generate_prompt(sample)
    assert prompt == rendered

tests/evals/tasks/test_math.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55

66
class MockTaskConfig:
    """Minimal stand-in for the task config consumed by MathTaskHandler."""

    # Template the handler fills in when building the model prompt.
    templating_parameters = {
        "template": "Return the answer to the following: {question}"
    }
    # Dataset keys the handler reads the ground-truth answer / question from.
    answer_key = "answer"
    question_key = "question"
1012

@@ -42,3 +44,20 @@ def test_check_correctness(
4244
):
4345
handler = MathTaskHandler(task_config=MockTaskConfig)
4446
assert handler.check_correctness(problem, generation=response) == expected
47+
48+
49+
@pytest.mark.parametrize(
    "sample, rendered",
    [
        pytest.param(
            {"question": "What is the result of 2+2?", "answer": "4"},
            "Return the answer to the following: What is the result of 2+2?",
            id="basic-template-fill",
        ),
    ],
)
def test_generate_prompt(sample, rendered):
    """generate_prompt substitutes the question into the configured template."""
    handler = MathTaskHandler(task_config=MockTaskConfig)
    assert handler.generate_prompt(sample) == rendered

tests/evals/tasks/test_mmlu.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
3+
from skythought.evals.tasks.mmlu.mmlu_handler import MMLUTaskHandler
4+
5+
6+
class MockTaskConfig:
    """Minimal stand-in for the task config consumed by MMLUTaskHandler."""

    # Template the handler fills in with the question and the choices text.
    templating_parameters = {"template": "{question}\n\nChoices:\n{choices}"}
    # Dataset keys the handler reads the answer / question / choices from.
    answer_key = "answer"
    question_key = "question"
    choices_key = "choices"
11+
12+
13+
@pytest.mark.parametrize(
    "sample, generation, want",
    [
        pytest.param(
            {
                "question": "What is the capital of France?",
                "choices": "A) London\nB) Paris\nC) Berlin\nD) Madrid",
                "answer": "B",
            },
            "The answer is B) Paris",
            True,
            id="stated-answer-matches",
        ),
        pytest.param(
            {
                "question": "Which element has the atomic number 1?",
                "choices": "A) Helium\nB) Oxygen\nC) Hydrogen\nD) Carbon",
                "answer": "C",
            },
            "C",
            False,
            id="bare-letter-not-extracted",
        ),
    ],
)
def test_check_correctness(sample, generation, want):
    """check_correctness returns the expected verdict for each MMLU sample."""
    verdict = MMLUTaskHandler(task_config=MockTaskConfig).check_correctness(
        sample, generation=generation
    )
    assert verdict == want
39+
40+
41+
@pytest.mark.parametrize(
    "problem, expected",
    [
        (
            {
                "question": "What is the capital of France?",
                # Fix: the fixture previously omitted the "choices" key even
                # though the template ("{question}\n\nChoices:\n{choices}")
                # and the expected prompt below both require it, so the
                # handler could not have produced the expected string.
                "choices": "A) London\nB) Paris\nC) Berlin\nD) Madrid",
                "answer": "B",
            },
            "What is the capital of France?\n\nChoices:\nA) London\nB) Paris\nC) Berlin\nD) Madrid",
        ),
    ],
)
def test_generate_prompt(problem, expected):
    """generate_prompt renders the template with the question and choices."""
    handler = MMLUTaskHandler(task_config=MockTaskConfig)
    assert handler.generate_prompt(problem) == expected

tests/evals/tasks/test_mmlu_pro.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pytest
2+
3+
from skythought.evals.tasks.mmlu.mmlu_handler import MMLUProTaskHandler
4+
5+
6+
class MockTaskConfig:
    """Minimal stand-in for the task config consumed by MMLUProTaskHandler."""

    # Template the handler fills in with the question and the choices text.
    templating_parameters = {"template": "Question: {question}\n\nChoices:\n{choices}"}
    # Dataset keys the handler reads answer / question / choices / context from.
    answer_key = "answer"
    question_key = "question"
    choices_key = "choices"
    context_key = "context"
12+
13+
14+
@pytest.mark.parametrize(
    "sample, generation, want",
    [
        pytest.param(
            {
                "question": "What is the main function of the left ventricle?",
                "choices": "A) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood",
                "answer": "B",
                "answer_index": 1,
            },
            "B) Pumps blood to the body",
            True,
            id="correct-choice-cited",
        ),
        pytest.param(
            {
                "question": "What does GDP stand for?",
                "choices": "A) Gross Domestic Product\nB) General Development Plan\nC) Global Distribution Process\nD) Geographic Data Point",
                "answer": "A",
                "answer_index": 0,
            },
            "I think it's B",
            False,
            id="wrong-choice-cited",
        ),
    ],
)
def test_check_correctness(sample, generation, want):
    """check_correctness returns the expected verdict for each MMLU-Pro sample."""
    handler = MMLUProTaskHandler(task_config=MockTaskConfig)
    assert handler.check_correctness(sample, generation=generation) == want
42+
43+
44+
@pytest.mark.parametrize(
    "sample, rendered",
    [
        pytest.param(
            {
                "question": "What is the main function of the left ventricle?",
                "choices": "A) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood",
                "answer": "B",
                "answer_index": 1,
            },
            (
                "Question: What is the main function of the left ventricle?\n\nChoices:"
                "\nA) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood"
            ),
            id="question-and-choices-rendered",
        ),
    ],
)
def test_generate_prompt(sample, rendered):
    """generate_prompt renders the template with the question and choices."""
    prompt = MMLUProTaskHandler(task_config=MockTaskConfig).generate_prompt(sample)
    assert prompt == rendered

0 commit comments

Comments
 (0)