
Commit c845250

AutoPDL Simplification (#1088)
* autopdl simplifications
* Lint
* Fix type error?
* fmt
* fix name & skip test
* Skip additional
* update gsm8k example

Signed-off-by: Claudio Spiess <[email protected]>
1 parent 179e092 commit c845250

18 files changed (+322 additions, -101 deletions)

docs/autopdl.md

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ Finally, we can run the example like so:
 
 ``` { .bash .copy .annotate linenums="1" }
 cd examples/optimizer
-python optimize.py optimize --config gsm8k_optimizer_config.yml --dataset-path ../../var/gsm8k_trajectified gsm8k.pdl
+python optimize.py optimize --config gsm8k_optimizer_config.yml --dataset-path ../../var/gsm8k_trajectified
 ```
 
 This will report details about the optimization process, such as the number of candidates evaluated. The output will look something like this:

examples/optimizer/bea19.pdl

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+lastOf:
+  - "Here are examples of grammatically incorrect sentences and their corrected versions:\n\n"
+  - for:
+      example: ${ demonstrations }
+    repeat:
+      text: "${ example.broken } -> ${ example.sentence }"
+    join:
+      with: "\n\n"
+  - "Correct the following sentence:\n\n${ broken }\nHere's the corrected sentence:\n\n"
+  - model: ${ model }
+    parameters:
+      max_tokens: 1024
+      temperature: 0
+      stop:
+        - "<|endoftext|>"
+        - "Question:"
+      include_stop_sequence: false
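
As a reading aid (not part of the commit), here is a rough Python sketch of the text this program builds up before the model call, assuming a couple of invented demonstration records with the `broken` and `sentence` fields used above; exact PDL context-accumulation details may differ.

```python
# Rough sketch (not part of the commit) of the prompt bea19.pdl assembles before the
# model call. The demonstration records and the instance sentence are invented; real
# ones come from the BEA-19 JSONL splits referenced in bea19_example.yml below.
demonstrations = [
    {"broken": "She no went to the market.", "sentence": "She did not go to the market."},
    {"broken": "I has two cat.", "sentence": "I have two cats."},
]
broken = "Him and me was late."

prompt = (
    "Here are examples of grammatically incorrect sentences and their corrected versions:\n\n"
    + "\n\n".join(f"{d['broken']} -> {d['sentence']}" for d in demonstrations)
    + f"Correct the following sentence:\n\n{broken}\nHere's the corrected sentence:\n\n"
)
print(prompt)
```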

examples/optimizer/bea19_example.yml

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+pdl_path: examples/optimizer/bea19.pdl # Path to the PDL file to optimize
+# benchmark: gretel-math # Name our benchmark
+dataset:
+  train: bea19_jsonl/train.jsonl # Path to the training split in JSONL format
+  test: bea19_jsonl/test.jsonl # Path to the test split in JSONL format
+  validation: bea19_jsonl/validation.jsonl # Path to the validation split in JSONL format
+
+demonstrations_variable_name: demonstrations # variable name to insert demonstrations into
+demonstration_columns:
+  - broken # column name for the question in the dataset
+  - sentence # column name for the answer in the dataset
+
+instance_columns:
+  - broken # column name for the question in the dataset
+
+groundtruth_column: sentence # column name for the ground truth in the dataset
+
+eval_pdl: examples/optimizer/eval_levenshtein.pdl # Path to the PDL file for evaluation
+
+budget: null # Set a budget, can be number of iterations, or a duration string e.g. "2h"
+budget_growth: double # double validation set size each iteration
+# or to_max: reach max_test_set_size by final iteration
+initial_test_set_size: 1 # size of test set in first iteration
+max_test_set_size: 1 # maximum test set size
+num_candidates: 100 # how many candidates to evaluate
+parallelism: 1 # how many threads to run evaluations across
+shuffle_test: false # shuffling of test set
+test_set_name: test # name of test set
+train_set_name: train # name of train set
+validation_set_name: validation # name of validation set
+variables: # define discrete options to sample from
+  model: # set ${ model } variable
+    - watsonx/meta-llama/llama-3-2-3b-instruct
+  num_demonstrations: # overrides num demonstrations above
+    - 0
+    - 3
+    - 5
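
For context, the JSONL splits referenced above are expected to carry the `broken` and `sentence` columns named in the config; a hypothetical record (values invented) could be written like this:

```python
# Hypothetical shape of one line in bea19_jsonl/train.jsonl, inferred from the
# demonstration_columns / groundtruth_column settings above. Field values are invented.
import json

record = {
    "broken": "She no went to the market.",       # instance / demonstration input
    "sentence": "She did not go to the market.",  # ground truth used by eval_levenshtein.pdl
}
print(json.dumps(record))  # JSONL stores one such JSON object per line
```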

examples/optimizer/eval_levenshtein.pdl

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+defs:
+  score:
+    function:
+      document: string
+      ground_truth: string
+    return:
+      lang: python
+      fallback: 0
+      code: |
+        import textdistance
+        result = textdistance.levenshtein.normalized_similarity(document, ground_truth)
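
The same scoring logic, written as a plain Python sketch for readers who want to check it outside of PDL (`textdistance` is the package the PDL code block imports):

```python
# Standalone sketch of the score function defined above: a normalized Levenshtein
# similarity in [0, 1], with 0 as the fallback when scoring fails (mirroring `fallback: 0`).
import textdistance


def score(document: str, ground_truth: str) -> float:
    try:
        return textdistance.levenshtein.normalized_similarity(document, ground_truth)
    except Exception:
        return 0


print(score("She did not go to market.", "She did not go to the market."))
```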

examples/optimizer/fever_evaluator.py

Lines changed: 9 additions & 2 deletions

@@ -82,5 +82,12 @@ def extract_answer(self, document: str) -> bool | None:
 
         return None
 
-    def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
-        return answer == truth or document.lower().endswith(str(truth).lower())
+    def score(self, document: str, ground_truth: Any) -> float:
+        answer = self.extract_answer(document)
+        if answer is None:
+            return 0.0
+
+        return float(
+            answer == ground_truth
+            or document.lower().endswith(str(ground_truth).lower())
+        )
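
The pattern here (and in the other evaluators below) is that the old `extract_answer`/`answer_correct` pair collapses into a single `score(document, ground_truth) -> float`. A minimal sketch of a custom evaluator following that contract is shown below; the `OptimizerEvaluator` constructor and any other required overrides are not visible in this diff, so treat it as illustrative only.

```python
# Minimal sketch of the consolidated scoring contract introduced in this commit.
# Only the score() signature is taken from the diff; everything else is assumed.
from typing import Any

from pdl.optimize.optimizer_evaluator import OptimizerEvaluator


class ExactSuffixEvaluator(OptimizerEvaluator):  # hypothetical custom evaluator
    def score(self, document: str, ground_truth: Any) -> float:
        # 1.0 if the model output ends with the ground truth, 0.0 otherwise
        return float(document.strip().endswith(str(ground_truth)))
```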

examples/optimizer/gsm8k_evaluator.py

Lines changed: 3 additions & 5 deletions

@@ -64,8 +64,6 @@ def get_scope(self) -> ScopeType:
         scope["reasoning"] = self.example["reasoning"]
         return empty_scope | scope
 
-    def extract_answer(self, document: str) -> Any:
-        return extract_math_answer(document)
-
-    def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
-        return answer == truth or document.endswith(f" {truth}")
+    def score(self, document: str, ground_truth: Any) -> float:
+        answer = extract_math_answer(document)
+        return float(answer == ground_truth or document.endswith(f" {ground_truth}"))
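
A quick worked example of the new `score` semantics (values invented; `extract_math_answer` is assumed to return the final numeric answer):

```python
# Invented example of Gsm8kEvaluator.score(): credit is given either for an exact
# match of the extracted answer or for a document that ends with " <ground_truth>".
document = "Each tray holds 6 eggs, so 3 trays hold 3 * 6 = 18. The answer is 18"
ground_truth = 18
answer = 18  # what extract_math_answer(document) would plausibly return

print(float(answer == ground_truth or document.endswith(f" {ground_truth}")))  # 1.0
```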

examples/optimizer/gsm8k_optimizer_config.yml

Lines changed: 17 additions & 1 deletion

@@ -1,4 +1,5 @@
-benchmark: gsm8k # Name our benchmark
+pdl_path: gsm8k.pdl # Path to the PDL file to optimize
+dataset: gsm8k # Name our benchmark
 budget: null # Set a budget, can be number of iterations, or a duration string e.g. "2h"
 budget_growth: double # double validation set size each iteration
 # or to_max: reach max_test_set_size by final iteration
@@ -12,6 +13,21 @@ test_set_name: test # name of test set
 train_set_name: train # name of train set
 validation_set_name: validation # name of validation set
 demonstrations_variable_name: demonstrations # variable name to insert demonstrations into
+demonstration_columns:
+  - question # column name for the question in the dataset
+  - reasoning
+  - answer
+  - traj_keys
+  - traj_values
+  - rewoo_traj_keys
+  - rewoo_traj_values
+
+instance_columns:
+  - question
+  - reasoning
+
+groundtruth_column: answer # column name for the ground truth in the dataset
+
 variables: # define discrete options to sample from
   model: # set ${ model } variable
     - watsonx/meta-llama/llama-3-2-3b-instruct
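
For orientation, here is a hedged sketch of how a config like this might be loaded, based only on the imports used in `examples/optimizer/optimize.py`; the exact `OptimizationConfig` constructor is an assumption, not something this diff shows.

```python
# Hedged sketch: parse the YAML above into an OptimizationConfig. The keyword-argument
# constructor is an assumption; the diff only shows that the class is imported and that
# it exposes pdl_path and dataset fields.
import yaml

from pdl.optimize.config_parser import OptimizationConfig

with open("gsm8k_optimizer_config.yml", encoding="utf-8") as f:
    config = OptimizationConfig(**yaml.safe_load(f))

print(config.pdl_path)  # "gsm8k.pdl" -- the optimizer now reads the PDL path from here
print(config.dataset)   # "gsm8k" -- selects the Gsm8kEvaluator branch in optimize.py
```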

examples/optimizer/gsmhard_evaluator.py

Lines changed: 13 additions & 7 deletions

@@ -6,12 +6,12 @@
 from pdl.pdl_interpreter import empty_scope
 
 
-def is_float(s: str) -> str:
+def is_float(s: str | float) -> str:
     try:
         f = float(s)
         return f"{f:.2f}"
     except Exception:
-        return s
+        return str(s)
 
 
 class GsmHardEvaluator(OptimizerEvaluator):
@@ -74,10 +74,16 @@ def get_scope(self) -> ScopeType:
         scope["question"] = self.example["input"]
         return empty_scope | scope
 
-    def extract_answer(self, document: str) -> float | int | None:
-        return extract_math_answer(document)
+    def score(self, document: str, ground_truth: Any) -> float:
+        answer = extract_math_answer(document)
+        if answer is None:
+            return 0.0
 
-    def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
         answerf = is_float(answer)
-        truthf = is_float(truth)
-        return answer == truth or answerf == truthf or document.endswith(f" {truth}")
+        truthf = is_float(ground_truth)
+
+        return float(
+            answer == ground_truth
+            or answerf == truthf
+            or document.endswith(f" {ground_truth}")
+        )
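
The widened `is_float` signature reflects that GSM-Hard answers can arrive as floats rather than strings; normalizing both sides to two decimal places makes `7` and `7.0` compare equal. A small self-contained check:

```python
# Self-contained check of the is_float() normalization used by GsmHardEvaluator.score().
def is_float(s: str | float) -> str:
    try:
        return f"{float(s):.2f}"
    except Exception:
        return str(s)


print(is_float("7"), is_float(7.0))  # "7.00" "7.00" -> equal after normalization
print(is_float("n/a"))               # non-numeric input falls back to str(s)
```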

examples/optimizer/mbpp_evaluator.py

Lines changed: 8 additions & 7 deletions

@@ -65,9 +65,10 @@ def extract_answer(self, document: str) -> str:
             solution = solution.split("```")[1]
         return solution.strip()
 
-    def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
-        if answer is None or not isinstance(answer, str):
-            return False
+    def score(self, document: str, ground_truth: Any) -> float:
+        answer = self.extract_answer(document)
+        if not answer:
+            return 0.0
 
         retry_parse = False
         try:
@@ -78,16 +79,16 @@ def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
 
         if retry_parse:
             pattern = r"```(?:python)?\n(.*?)\n```"
-            match = re.search(pattern, answer, re.DOTALL)
+            match = re.search(pattern, document, re.DOTALL)
             if match:
                 answer = match.group(1)
                 try:
                     ast.parse(answer)
                 except Exception as e:
                     print(e)
-                    return False
+                    return 0.0
             else:
-                return False
+                return 0.0
 
         task_id = self.example["task_id"]
 
@@ -109,4 +110,4 @@ def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:
         base_stat, _ = result["base"]
         plus_stat, _ = result["plus"]
 
-        return base_stat == "pass" and plus_stat == "pass"
+        return float(base_stat == "pass" and plus_stat == "pass")
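
The key fix in the retry path is searching the raw `document` rather than the already-extracted `answer` for a fenced code block. A small demonstration of that regex (the sample document is invented):

```python
# Demonstration of the retry-path extraction: search the raw document for a fenced
# Python block and check that it parses. The sample document is invented; the fence
# characters are built programmatically so this snippet stays easy to embed.
import ast
import re

fence = "`" * 3
document = f"Some reasoning first.\n{fence}python\ndef add(a, b):\n    return a + b\n{fence}\nDone."

pattern = rf"{fence}(?:python)?\n(.*?)\n{fence}"
match = re.search(pattern, document, re.DOTALL)
if match:
    answer = match.group(1)
    ast.parse(answer)  # raises SyntaxError if the extracted snippet is not valid Python
    print(answer)
```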

examples/optimizer/optimize.py

Lines changed: 29 additions & 18 deletions

@@ -5,14 +5,16 @@
 from typing import Any
 
 import yaml
-from datasets.load import load_from_disk
+from datasets import load_dataset, load_from_disk
 from fever_evaluator import FEVEREvaluator
 from gsm8k_evaluator import Gsm8kEvaluator
 from gsmhard_evaluator import GsmHardEvaluator
 from mbpp_dataset import MBPPDataset
 from mbpp_evaluator import MBPPEvaluator
 
-from pdl.optimize.config_parser import OptimizationConfig
+from pdl.optimize.config_parser import JsonlDataset, OptimizationConfig
+from pdl.optimize.optimizer_evaluator import OptimizerEvaluator
+from pdl.optimize.pdl_evaluator import PdlEvaluator
 from pdl.optimize.pdl_optimizer import PDLOptimizer
 
 if __name__ == "__main__":
@@ -38,7 +40,7 @@
         "--dataset-path",
         help="Path to the dataset directory",
         type=Path,
-        required=True,
+        required=False,
     )
     common_parser.add_argument(
         "--experiments-path",
@@ -56,11 +58,6 @@
         action=argparse.BooleanOptionalAction,
         default=False,
     )
-    common_parser.add_argument(
-        "pdl_file",
-        type=Path,
-        help="Path to a PDL file to optimize",
-    )
 
     # Optimize command
     optimize_parser = subparsers.add_parser(
@@ -82,9 +79,6 @@
     )
 
     args = parser.parse_args()
-    if not args.pdl_file.exists():
-        print("PDL file doesn't exist:", args.pdl_file)
-        sys.exit(1)
 
     if not args.config.exists():
         print("Config file doesn't exist:", args.config)
@@ -100,35 +94,52 @@
         traceback.print_last()
         sys.exit(1)
 
+    if not Path(config.pdl_path).exists():
+        print("PDL file doesn't exist:", config.pdl_path)
+        sys.exit(1)
+
     if args.dry:
         sys.exit(0)
 
     # Set up dataset and trial thread based on benchmark
     dataset: Any
     TrialThread: type[
-        Gsm8kEvaluator | GsmHardEvaluator | FEVEREvaluator | MBPPEvaluator
+        Gsm8kEvaluator
+        | GsmHardEvaluator
+        | FEVEREvaluator
+        | MBPPEvaluator
+        | OptimizerEvaluator
     ]
 
-    if config.benchmark == "gsm8k":
+    if config.dataset == "gsm8k":
         dataset = load_from_disk(args.dataset_path)
         TrialThread = Gsm8kEvaluator
-    elif config.benchmark == "gsmhard":
+    elif config.dataset == "gsmhard":
         dataset = load_from_disk(args.dataset_path)
         TrialThread = GsmHardEvaluator
-    elif config.benchmark == "fever":
+    elif config.dataset == "fever":
         fever = load_from_disk(args.dataset_path)
         dataset = fever
         TrialThread = FEVEREvaluator
-    elif config.benchmark == "mbpp":
+    elif config.dataset == "mbpp":
        dataset = MBPPDataset(args.dataset_path)
        TrialThread = MBPPEvaluator
+    elif isinstance(config.dataset, (dict, JsonlDataset)):
+        dataset = load_dataset(
+            "json",
+            data_files={
+                "train": config.dataset.train,
+                "validation": config.dataset.validation,
+                "test": config.dataset.test,
+            },
+        )
+        TrialThread = PdlEvaluator
     else:
-        print(f"Unknown benchmark: {config.benchmark}")
+        print(f"Unknown dataset: {config.dataset}")
         sys.exit(1)
 
     # Create optimizer instance
     optimizer = PDLOptimizer(
-        pdl_path=args.pdl_file,
         dataset=dataset,
         trial_thread=TrialThread,
         yield_output=args.yield_output,
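
The new branch is the piece that lets arbitrary JSONL-backed configs (such as `bea19_example.yml`) run without a benchmark-specific evaluator: the splits are loaded with the Hugging Face `datasets` JSON loader and scoring is delegated to the generic `PdlEvaluator`. Pulled out as a standalone sketch, with paths taken from the bea19 config:

```python
# Standalone sketch of what the JsonlDataset branch does: load train/validation/test
# JSONL splits with the Hugging Face `datasets` JSON loader. Paths are the ones from
# bea19_example.yml; run from a directory containing bea19_jsonl/ (or adjust paths).
from datasets import load_dataset

dataset = load_dataset(
    "json",
    data_files={
        "train": "bea19_jsonl/train.jsonl",
        "validation": "bea19_jsonl/validation.jsonl",
        "test": "bea19_jsonl/test.jsonl",
    },
)
print(dataset)  # DatasetDict with train/validation/test splits
```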
