Skip to content

Commit c01605f

Browse files
committed
Lint
Signed-off-by: Claudio Spiess <[email protected]>
1 parent 0272c66 commit c01605f

File tree

3 files changed

+12
-37
lines changed

3 files changed

+12
-37
lines changed

src/pdl/optimize/mbpp_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ def select(self, iterable):
1414

1515

1616
class MBPPDataset(dict):
17-
def __init__(self) -> None:
17+
def __init__(self, dataset_path: str) -> None:
1818
self.mbpp_plus = get_mbpp_plus()
1919
self.dataset_hash = get_mbpp_plus_hash()
2020

@@ -24,9 +24,7 @@ def __init__(self) -> None:
2424
MBPP_OUTPUT_NOT_NONE_TASKS,
2525
)
2626

27-
self.mbpp = load_from_disk(
28-
"../prompt-declaration-language-merge/var/mbpp_trajectified",
29-
).rename_column(
27+
self.mbpp = load_from_disk(dataset_path).rename_column(
3028
"code",
3129
"canonical_solution",
3230
)

src/pdl/optimize/optimize.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@
121121
dataset = fever
122122
TrialThread = FEVEREvaluator
123123
elif config.benchmark == "evalplus":
124-
dataset = MBPPDataset()
124+
dataset = MBPPDataset(args.dataset_path)
125125
TrialThread = MBPPEvaluator
126126
else:
127127
print(f"Unknown benchmark: {config.benchmark}")

tests/test_optimizer.py

Lines changed: 9 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from pprint import pprint
33

44
from datasets import Dataset, DatasetDict
5+
from pytest import skip
56

67
from pdl.optimize.config_parser import OptimizationConfig
78
from pdl.optimize.fever_evaluator import FEVEREvaluator
@@ -731,7 +732,9 @@ def run_optimizer_mbpp(pattern, num_demonstrations=0):
731732
},
732733
)
733734

734-
mbpp_dataset = MBPPDataset()
735+
mbpp_dataset = MBPPDataset(
736+
"../prompt-declaration-language-merge/var/mbpp_trajectified",
737+
)
735738

736739
optim = PDLOptimizer(
737740
pdl_path=Path("examples/optimizer/evalplus.pdl"),
@@ -748,23 +751,25 @@ def run_optimizer_mbpp(pattern, num_demonstrations=0):
748751
pprint(result)
749752

750753

754+
@skip("API access not available in CI")
751755
def test_gsm8k_zeroshot_cot():
752756
run_optimizer_gsm8k("cot")
753757

754758

755-
def test_gsm8k_zeroshot_react():
759+
@skip("API access not available in CI")
760+
def test_gsm8k_fiveshot_react():
756761
run_optimizer_gsm8k("react", num_demonstrations=5)
757762

758763

759-
def test_gsm8k_zeroshot_rewoo():
764+
def test_gsm8k_fiveshot_rewoo():
760765
run_optimizer_gsm8k("rewoo", num_demonstrations=5)
761766

762767

763768
def test_fever_zeroshot_cot():
764769
run_optimizer_fever("cot")
765770

766771

767-
def test_fever_zeroshot_react():
772+
def test_fever_fiveshot_react():
768773
run_optimizer_fever("react", num_demonstrations=5)
769774

770775

@@ -778,31 +783,3 @@ def test_mbpp_zeroshot_cot():
778783

779784
def test_mbpp_zeroshot_react():
780785
run_optimizer_mbpp("react")
781-
782-
783-
# def test_valid_experiment_programs(capsys: CaptureFixture[str]) -> None:
784-
# actual_invalid: set[str] = set()
785-
# with_warnings: set[str] = set()
786-
# prompt_library = Path("contrib/prompt_library")
787-
# optimizer_examples = Path("contrib/prompt_library")
788-
# programs = [
789-
# "CoT.pdl",
790-
# "ReAct.pdl",
791-
# "ReWoo.pdl",
792-
# "tools.pdl",
793-
# "examples/evalplus/general.pdl",
794-
# "examples/fever/general.pdl",
795-
# "examples/gsm8k/general.pdl",
796-
# ]
797-
# program_paths = [prompt_library / p for p in programs]
798-
# for yaml_file_name in program_paths:
799-
# try:
800-
# _ = parse_file(yaml_file_name)
801-
# captured = capsys.readouterr()
802-
# if len(captured.err) > 0:
803-
# with_warnings |= {str(yaml_file_name)}
804-
# except PDLParseError:
805-
# actual_invalid |= {str(yaml_file_name)}
806-
807-
# assert len(actual_invalid) == 0, actual_invalid
808-
# assert len(with_warnings) == 0, with_warnings

0 commit comments

Comments
 (0)