Skip to content

Commit ee45d5e

Browse files
committed
Add optimizer test PDL
Signed-off-by: Claudio Spiess <[email protected]>
1 parent 5001f92 commit ee45d5e

File tree

3 files changed

+82
-7
lines changed

3 files changed

+82
-7
lines changed

src/pdl/optimize/pdl_optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def sample_candidates(
161161

162162
if (
163163
"prompt_pattern" in self.config.variables
164-
and "cot" in self.config.variables["prompt_pattern"]
164+
and "cot" in self.config.variables.get("prompt_pattern", [])
165+
and 0 in self.config.variables.get("num_demonstrations", [])
165166
):
166167
cot_candidate = {
167168
k: self.sample_random_index(v) for k, v in self.config.variables.items()
@@ -183,9 +184,8 @@ def sample_candidates(
183184
k: self.sample_random_index(v) for k, v in self.config.variables.items()
184185
}
185186
if (
186-
"num_demonstrations" in variable_instance
187-
and variable_instance["num_demonstrations"] == 0
188-
and variable_instance["prompt_pattern"] == "cot"
187+
variable_instance.get("num_demonstrations") == 0
188+
and variable_instance.get("prompt_pattern") == "cot"
189189
):
190190
if variable_instance["prompt_pattern"] in zero_shots_seen:
191191
continue

src/pdl/pdl_dumper.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,7 @@ def build_exclude(obj: Any, regex: re.Pattern[str]) -> Any:
416416
for k, v in obj.items():
417417
nested = build_exclude(v, regex)
418418
if nested:
419-
out[k] = nested
420-
if k == "root":
421-
out |= out["root"]
419+
out[k] = {**nested, "__all__": nested}
422420

423421
return out or None
424422

@@ -440,6 +438,7 @@ def dump_program_exclude_internals(program: Program) -> str:
440438
# pattern for internal pdl__ fields
441439
regex = re.compile(r"^pdl__.*")
442440
exclude = build_exclude(program, regex)
441+
exclude |= exclude.get("root", {})
443442
return dump_program(program, exclude=exclude)
444443

445444

tests/data/optimizer_gsm8k.pdl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
description: Demo of template
2+
defs:
3+
cot:
4+
import: ../../contrib/prompt_library/CoT
5+
react:
6+
import: ../../contrib/prompt_library/ReAct
7+
rewoo:
8+
import: ../../contrib/prompt_library/ReWoo
9+
tools:
10+
import: ../../contrib/prompt_library/tools
11+
12+
chain_of_thought:
13+
function:
14+
question: str
15+
model: str
16+
examples:
17+
{ list: { obj: { question: str, reasoning: str, answer: str } } }
18+
return:
19+
lastOf:
20+
- call: ${ cot.fewshot_cot }
21+
args:
22+
examples: ${ examples }
23+
- "Question: ${ question }\n"
24+
- "Answer: Let's think step by step. "
25+
- model: ${ model }
26+
def: answer
27+
parameters:
28+
max_tokens: 1024
29+
temperature: 0
30+
stop:
31+
- "<|endoftext|>"
32+
- "Question:"
33+
include_stop_sequence: false
34+
mock_response: "144"
35+
- data:
36+
answer: ${ answer|trim }
37+
match: ${ prompt_pattern }
38+
with:
39+
# CoT
40+
- case: cot
41+
then:
42+
text:
43+
- "Answer the questions to the best of your abilities.\n\n"
44+
- call: ${ chain_of_thought }
45+
def: ANSWER
46+
contribute: []
47+
args:
48+
examples: "${ demonstrations }"
49+
question: "${ question|trim }"
50+
model: "${ model }"
51+
- "\nThe answer is ${ ANSWER.answer|trim }"
52+
53+
# ReAct
54+
- case: react
55+
then:
56+
text:
57+
call: ${ react.react }
58+
args:
59+
task: "Question: ${ question|trim }"
60+
model: ${ model }
61+
tool_schema: ${ tools.tool_schema }
62+
tools: ${ tools.tools }
63+
trajectories: ${ demonstrations }
64+
65+
# ReWOO
66+
- case: rewoo
67+
then:
68+
text:
69+
call: ${ rewoo.rewoo }
70+
args:
71+
task: ${ question|trim }
72+
model: ${ model }
73+
tool_schema: ${ tools.tool_schema }
74+
tools: ${ tools.tools }
75+
trajectories: ${ demonstrations }
76+
show_plans: false

0 commit comments

Comments
 (0)