Skip to content

Commit 618b838

Browse files
authored
gsm8k multi-plan, tree-of-thought, tree-of-thought with few shots (#881)
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 91f2784 commit 618b838

File tree

3 files changed

+596
-0
lines changed

3 files changed

+596
-0
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
description: Grade School Math -- for every problem we generate a plan, then exectute and evaluate it.
2+
defs:
3+
problems:
4+
read: ./test.jsonl
5+
parser: jsonl
6+
7+
MAX_ITERATIONS: 10
8+
N: 3
9+
10+
11+
majority_vote:
12+
function:
13+
numbers: [float]
14+
return:
15+
lang: python
16+
code: |
17+
from collections import Counter
18+
frequency = Counter( ${ numbers })
19+
most_frequent = max(frequency, key=frequency.get)
20+
result = most_frequent
21+
22+
majority_vote_json:
23+
function:
24+
results: [{ "result": float }]
25+
return:
26+
lastOf:
27+
- lang: python
28+
def: numbers
29+
code: |
30+
result = [o["result"] for o in ${ results }]
31+
- call: ${ majority_vote }
32+
args:
33+
numbers: ${ numbers }
34+
35+
planning:
36+
function:
37+
problem: str
38+
demos: [str]
39+
return:
40+
text:
41+
- |
42+
Please generate a high-level plan for solving the following question.
43+
As the first step, just say what method and idea you will use to solve the question.
44+
You can reorganize the information in the question. Do not do the actual calculation.
45+
Keep your response concise and within 80 words.
46+
47+
- for:
48+
demo: ${ demos }
49+
repeat:
50+
${ demo }
51+
join:
52+
with: "\n"
53+
- text:
54+
- "\nProblem:\n"
55+
- ${ problem }
56+
- "\n"
57+
- model: ollama/granite3.2:8b
58+
parameters:
59+
temperature: 0.7
60+
top_p: 0.85
61+
62+
solve:
63+
function:
64+
plan: str
65+
return:
66+
text:
67+
- ${ plan }
68+
- |
69+
70+
The plan looks good! Now, use real numbers and do the calculation. Please solve the question
71+
step-by-step according to the high-level plan. Give me the final answer. Make your response short.
72+
- "\nThe answer is:\n"
73+
- model: ollama/granite3.2:8b
74+
parameters:
75+
temperature: 0.7
76+
top_p: 0.85
77+
78+
extract_final_answer:
79+
function:
80+
solution: str
81+
return:
82+
lastOf:
83+
- ${ solution }
84+
- Extract the result from the above solution into a JSON object with field "result" and a float as value. Remove any dollar signs or other symbols.
85+
- model: ollama/granite3.2:8b
86+
parser: json
87+
def: result
88+
spec: { "result": float }
89+
fallback:
90+
data:
91+
result: 0
92+
93+
compare_to_ground_truth:
94+
function:
95+
result: float
96+
truth: str
97+
return:
98+
lastOf:
99+
- data: ${ truth }
100+
parser:
101+
regex: "(.|\n)*#### (?P<answer>([0-9])*)\n*"
102+
spec:
103+
answer: str
104+
def: ground_truth
105+
- if: ${ result|float == ground_truth.answer|float}
106+
then:
107+
1
108+
else:
109+
0
110+
111+
text:
112+
- defs:
113+
demos:
114+
read: demos.yaml
115+
parser: yaml
116+
for:
117+
problem: ${ problems }
118+
repeat:
119+
repeat:
120+
call: ${ planning }
121+
args:
122+
pdl_context: []
123+
problem: ${ problem.question }
124+
demos: ${ demos }
125+
max_iterations: ${ N }
126+
join:
127+
as: array
128+
max_iterations: ${ MAX_ITERATIONS }
129+
def: plans
130+
join:
131+
as: array
132+
133+
- for:
134+
plans_for_problem: ${ plans }
135+
repeat:
136+
for:
137+
plan: ${ plans_for_problem }
138+
repeat:
139+
repeat:
140+
call: ${ solve }
141+
args:
142+
pdl_context: []
143+
plan: ${ plan }
144+
max_iterations: ${ N }
145+
join:
146+
as: array
147+
join:
148+
as: array
149+
max_iterations: ${ MAX_ITERATIONS }
150+
def: solutions
151+
join:
152+
as: array
153+
154+
- for:
155+
solution: ${ solutions }
156+
repeat:
157+
for:
158+
solutions_for_problem: ${ solution }
159+
repeat:
160+
for:
161+
solution_for_problem: ${ solutions_for_problem }
162+
repeat:
163+
call: ${ extract_final_answer }
164+
args:
165+
pdl_context: []
166+
solution: ${ solution_for_problem }
167+
max_iterations: ${ N }
168+
join:
169+
as: array
170+
join:
171+
as: array
172+
max_iterations: ${ MAX_ITERATIONS }
173+
def: results
174+
join:
175+
as: array
176+
177+
- for:
178+
all_results_for_problem: ${ results }
179+
repeat:
180+
for:
181+
results_for_problem: ${ all_results_for_problem }
182+
repeat:
183+
call: ${ majority_vote_json }
184+
args:
185+
pdl_context: []
186+
results: ${ results_for_problem }
187+
max_iterations: ${ N }
188+
join:
189+
as: array
190+
max_iterations: ${ MAX_ITERATIONS }
191+
def: per_plan_votes
192+
join:
193+
as: array
194+
195+
- for:
196+
votes: ${ per_plan_votes }
197+
repeat:
198+
call: ${ majority_vote }
199+
args:
200+
pdl_context: []
201+
numbers: ${ votes }
202+
max_iterations: ${ MAX_ITERATIONS }
203+
join:
204+
as: array
205+
def: results
206+
207+
- for:
208+
result: ${ results }
209+
problem: ${ problems[:MAX_ITERATIONS] }
210+
repeat:
211+
call: ${ compare_to_ground_truth }
212+
args:
213+
pdl_context: []
214+
result: ${ result }
215+
truth: ${ problem.answer }
216+
max_iterations: ${ MAX_ITERATIONS }
217+
def: stats
218+
join:
219+
as: array
220+
221+
- "\nAccuracy: ${ stats|sum / MAX_ITERATIONS * 100}% "
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
description: Grade School Math -- for every problem we generate a plan, then exectute and evaluate it.
2+
defs:
3+
problems:
4+
read: ./test.jsonl
5+
parser: jsonl
6+
7+
MAX_ITERATIONS: 2
8+
N: 3
9+
10+
11+
majority_vote:
12+
function:
13+
results: [{ "result": float }]
14+
return:
15+
lang: python
16+
code: |
17+
from collections import Counter
18+
numbers = [o["result"] for o in ${ results }]
19+
frequency = Counter(numbers)
20+
most_frequent = max(frequency, key=frequency.get)
21+
result = most_frequent
22+
planning:
23+
function:
24+
problem: str
25+
return:
26+
text:
27+
- |
28+
Please generate a high-level plan for solving the following question.
29+
As the first step, just say what method and idea you will use to solve the question.
30+
You can reorganize the information in the question. Do not do the actual calculation.
31+
Keep your response concise and within 80 words.
32+
33+
- "\nProblem:\n"
34+
- ${ problem }
35+
- "\n"
36+
- model: ollama/granite3.2:8b
37+
parameters:
38+
temperature: 0.7
39+
top_p: 0.85
40+
41+
solve:
42+
function:
43+
plan: str
44+
return:
45+
text:
46+
- ${ plan }
47+
- |
48+
49+
The plan looks good! Now, use real numbers and do the calculation. Please solve the question
50+
step-by-step according to the high-level plan. Give me the final answer. Make your response short.
51+
- "\nThe answer is:\n"
52+
- model: ollama/granite3.2:8b
53+
parameters:
54+
temperature: 0.7
55+
top_p: 0.85
56+
57+
extract_final_answer:
58+
function:
59+
solution: str
60+
return:
61+
lastOf:
62+
- ${ solution }
63+
- Extract the result from the above solution into a JSON object with field "result" and a float as value. Remove any dollar signs or other symbols.
64+
- model: ollama/granite3.2:8b
65+
parser: json
66+
def: result
67+
spec: { "result": float }
68+
fallback:
69+
data:
70+
result: 0
71+
72+
compare_to_ground_truth:
73+
function:
74+
result: float
75+
truth: str
76+
return:
77+
lastOf:
78+
- data: ${ truth }
79+
parser:
80+
regex: "(.|\n)*#### (?P<answer>([0-9])*)\n*"
81+
spec:
82+
answer: str
83+
def: ground_truth
84+
- if: ${ result|float == ground_truth.answer|float}
85+
then:
86+
1
87+
else:
88+
0
89+
90+
text:
91+
- for:
92+
problem: ${ problems }
93+
repeat:
94+
repeat:
95+
call: ${ planning }
96+
args:
97+
pdl_context: []
98+
problem: ${ problem.question }
99+
max_iterations: ${ N }
100+
join:
101+
as: array
102+
max_iterations: ${ MAX_ITERATIONS }
103+
def: plans
104+
join:
105+
as: array
106+
107+
- for:
108+
plans_for_problem: ${ plans }
109+
repeat:
110+
for:
111+
plan: ${ plans_for_problem }
112+
repeat:
113+
call: ${ solve }
114+
args:
115+
pdl_context: []
116+
plan: ${ plan }
117+
join:
118+
as: array
119+
max_iterations: ${ MAX_ITERATIONS }
120+
def: solutions
121+
join:
122+
as: array
123+
124+
- for:
125+
solution: ${ solutions }
126+
repeat:
127+
for:
128+
solution_for_problem: ${ solution }
129+
repeat:
130+
call: ${ extract_final_answer }
131+
args:
132+
pdl_context: []
133+
solution: ${ solution_for_problem }
134+
join:
135+
as: array
136+
max_iterations: ${ MAX_ITERATIONS }
137+
def: results
138+
join:
139+
as: array
140+
141+
- for:
142+
results_for_problem: ${ results }
143+
repeat:
144+
call: ${ majority_vote }
145+
args:
146+
pdl_context: []
147+
results: ${ results_for_problem }
148+
max_iterations: ${ MAX_ITERATIONS }
149+
def: votes
150+
join:
151+
as: array
152+
153+
- for:
154+
result: ${ votes }
155+
problem: ${ problems[:MAX_ITERATIONS] }
156+
repeat:
157+
call: ${ compare_to_ground_truth }
158+
args:
159+
pdl_context: []
160+
result: ${ result }
161+
truth: ${ problem.answer }
162+
max_iterations: ${ MAX_ITERATIONS }
163+
def: stats
164+
join:
165+
as: array
166+
167+
- "\nAccuracy: ${ stats|sum / MAX_ITERATIONS * 100}% "

0 commit comments

Comments
 (0)