
Commit cdfa861

ErnestinaQiu and root authored
[Hackathon 5th No.73] ToT (PaddlePaddle#7660)

* Hackathon TASK73 ToT: finish meta/llama2 version
* update readme tutorial
* modify according to Lint
* modify according to Lint: resolve one unused variable
* Delete LICENSE
* Update LICENSE
* black format
* isort format
* Update search_crosswords-dfs.ipynb
* update files formats
* Update LICENSE (several times)
* delete test data
* delete some unnecessary files according to comments
* add paddlenlp-llama2: add llama2 in paddlenlp
* fix one bug
* fix outputs bug: format data structure
* delete meta/llama2
* modify according to comments: add acknowledgement to readme; change png into url in readme; add all the models supported by paddlenlp
* change according to comments
* Delete .gitignore
* Create .gitignore
* Move directory
* Add tree of thoughts scripts
* add first dir
* add note
* Update README.md: add test results of facebook/llama-2-7b-chat and llama-2-13b-chat
* Update requirements.txt: delete unnecessary packages
* Update demo.py: add Ernie
* Update .gitignore: delete pyproject.toml
* Update run.py: add Ernie
* Update __init__.py: add Ernie
* chat templates
* add Ernie
* Update llama.py: make compatible with Ernie
* Update bfs.py: make compatible with Ernie
* Update models.py: make compatible with Ernie
* Update run.py
* format style (repeated over several commits)
* remove the duplicated "Test results" section
* remove the Ernie token; read it from an environment variable instead
* remove commented-out code

Co-authored-by: root <[email protected]>

1 parent 3a42280 commit cdfa861

File tree

20 files changed: +1989 −1 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -123,4 +123,4 @@ FETCH_HEAD
 # vscode
 .vscode
-./ppdiffusers/ppdiffusers/version.py
+./ppdiffusers/ppdiffusers/version.py
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
# Tree of Thoughts (ToT)

![teaser](https://github.com/PaddlePaddle/PaddleNLP/assets/48557439/30f9e365-398a-4822-b3c2-a0768f70e310)

An implementation of the code, prompts, and model outputs of the paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601).

## Setup
1. Installation
```bash
git clone [email protected]:PaddlePaddle/PaddleNLP.git
cd PaddleNLP/pipelines/examples/tree-of-thought/
pip install -r requirements.txt
```

2. Download the test data from https://github.com/ErnestinaQiu/tree-of-thought-llm/tree/master/src/tot/data and place it under pipelines/examples/tree-of-thought/tree/master/src/tot/data

## Quick Start
The following script tries to solve the Game of 24 with the numbers 4, 5, 6, and 10 (it may be somewhat slow because it uses llama-2-7b-chat).

Run it from the directory pipelines/examples/agents/tree-of-thought-llm:

```
python demo.py
```

The demo is equivalent to the following Python code:

```python
import argparse

from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task

args = argparse.Namespace(
    backend="llama-2-7b-chat",
    temperature=0.6,
    task="game24",
    naive_run=False,
    prompt_sample=None,
    method_generate="propose",
    method_evaluate="value",
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
)

task = Game24Task()
ys, infos = solve(args, task, 900)
print(ys[0])
```

The output may look like the following (note that it is not deterministic, and can sometimes be wrong):
```
10 - 4 = 6 (left: 5 6 6)
5 * 6 = 30 (left: 6 30)
30 - 6 = 24 (left: 24)
Answer: (5 * (10 - 4)) - 6 = 24
```
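An answer like the one above can be checked mechanically. The helper below is a hypothetical sketch (not part of this repo) that verifies a Game of 24 expression uses exactly the given numbers and evaluates to 24:

```python
# Verify a Game of 24 answer: the expression must use exactly the
# given numbers and evaluate to 24. Hypothetical helper for illustration.
import re


def check_answer(expression: str, numbers: list) -> bool:
    used = sorted(int(n) for n in re.findall(r"\d+", expression))
    if used != sorted(numbers):
        return False
    # eval is acceptable here only because the input is our own expression
    return eval(expression) == 24


print(check_answer("(5 * (10 - 4)) - 6", [4, 5, 6, 10]))  # True
```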

## Paper experiments

Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``.

The very simple ``run.py`` implements the ToT + BFS algorithm, as well as naive IO/CoT sampling. Some key arguments:

- ``--naive_run``: if True, run naive IO/CoT sampling instead of ToT + BFS.
- ``--prompt_sample`` (choices=[``standard``, ``cot``]): sampling prompt
- ``--method_generate`` (choices=[``sample``, ``propose``]): thought generator, whether to sample independent thoughts (used in creative writing) or propose sequential thoughts (used in the Game of 24)
- ``--method_evaluate`` (choices=[``value``, ``vote``]): state evaluator, whether to value states independently (used in the Game of 24) or vote on states together (used in creative writing)
- ``--n_generate_sample``: number of times to prompt for thought generation
- ``--n_evaluate_sample``: number of times to prompt for state evaluation
- ``--n_select_sample``: number of states to keep at each step (i.e. ``b`` in the ToT + BFS algorithm of the paper)
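The ToT + BFS loop that these flags parameterize can be sketched as follows (a minimal illustration with plug-in `generate` and `evaluate` callables; the real implementation in ``tot/methods/bfs.py`` prompts the LLM for both steps):

```python
# Simplified ToT + BFS: at each step, expand every kept state into
# candidate thoughts, score the candidates, and greedily keep the
# best b (= n_select_sample) of them.
def tot_bfs(x, generate, evaluate, steps, n_generate_sample, n_select_sample):
    frontier = [""]  # partial solutions ("thoughts so far")
    for _ in range(steps):
        candidates = []
        for y in frontier:
            candidates.extend(generate(x, y, n_generate_sample))
        scored = [(evaluate(x, y), y) for y in candidates]
        scored.sort(key=lambda t: t[0], reverse=True)  # greedy selection
        frontier = [y for _, y in scored[:n_select_sample]]
    return frontier


# Toy usage: build the string "aaa" by appending characters,
# scoring states by the number of 'a' characters.
gen = lambda x, y, k: [y + c for c in "ab"][:k]
ev = lambda x, y: y.count("a")
print(tot_bfs("goal", gen, ev, steps=3, n_generate_sample=2, n_select_sample=1))
# ['aaa']
```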

## Paper trajectories

``logs/`` contains all the trajectories from the paper's experiments, except ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json``, which was regenerated after the paper (as the original experiment was done in a notebook); due to randomness in GPT decoding, its score dropped from the original 74% to 69%. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this should not affect the main conclusions of the paper.
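Each trajectory file is a JSON list of per-puzzle records. Aggregating the two metrics that ``run.py`` reports (average accuracy and any-success accuracy) from such records might look like this sketch (the ``infos``/``r`` field names follow what ``run.py`` logs; the inline records are toy data, not real results):

```python
# Aggregate accuracy metrics from trajectory records as written by
# run.py: each record holds per-sample results under infos[*]["r"].
def summarize(records):
    cnt_avg = sum(
        sum(info["r"] for info in rec["infos"]) / len(rec["infos"])
        for rec in records
    )
    cnt_any = sum(any(info["r"] for info in rec["infos"]) for rec in records)
    n = len(records)
    return cnt_avg / n, cnt_any / n


records = [  # toy inline data instead of reading a logs/... file
    {"idx": 900, "infos": [{"r": 1}, {"r": 0}]},
    {"idx": 901, "infos": [{"r": 0}, {"r": 0}]},
]
print(summarize(records))  # (0.25, 0.5)
```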

## Task scripts for the paper experiments

### crosswords
Naive CoT sampling over the first 20 crossword puzzles, with 10 generations per prompt:
```bash
python run.py \
    --task crosswords \
    --task_start_index 0 \
    --task_end_index 20 \
    --naive_run \
    --prompt_sample cot \
    --n_generate_sample 10
```

The same range with standard (IO) prompting:
```bash
python run.py \
    --task crosswords \
    --task_start_index 0 \
    --task_end_index 20 \
    --naive_run \
    --prompt_sample standard \
    --n_generate_sample 10
```

### game24
ToT + BFS on Game of 24 puzzles 900–1000: propose sequential thoughts, value states independently (3 evaluation samples), and greedily keep the best 5 states at each step:
```bash
python run.py \
    --task game24 \
    --task_start_index 900 \
    --task_end_index 1000 \
    --method_generate propose \
    --method_evaluate value \
    --method_select greedy \
    --n_evaluate_sample 3 \
    --n_select_sample 5
```

Naive CoT sampling, with 100 generations per puzzle:
```bash
python run.py \
    --task game24 \
    --task_start_index 900 \
    --task_end_index 1000 \
    --naive_run \
    --prompt_sample cot \
    --n_generate_sample 100
```

Naive standard (IO) sampling:
```bash
python run.py \
    --task game24 \
    --task_start_index 900 \
    --task_end_index 1000 \
    --naive_run \
    --prompt_sample standard \
    --n_generate_sample 100
```

### text (creative writing)
ToT + BFS on creative-writing tasks 0–100: sample 5 independent thoughts per step with CoT prompting at temperature 1.0, vote on states (5 evaluation samples), and greedily keep the single best state:
```bash
python run.py \
    --task text \
    --task_start_index 0 \
    --task_end_index 100 \
    --method_generate sample \
    --method_evaluate vote \
    --method_select greedy \
    --n_generate_sample 5 \
    --n_evaluate_sample 5 \
    --n_select_sample 1 \
    --prompt_sample cot \
    --temperature 1.0
```

Naive CoT sampling:
```bash
python run.py \
    --task text \
    --task_start_index 0 \
    --task_end_index 100 \
    --naive_run \
    --prompt_sample cot \
    --n_generate_sample 10 \
    --temperature 1.0
```

Naive standard (IO) sampling:
```bash
python run.py \
    --task text \
    --task_start_index 0 \
    --task_end_index 100 \
    --naive_run \
    --prompt_sample standard \
    --n_generate_sample 10 \
    --temperature 1.0
```

## Test results
These tests use facebook/llama-2-7b-chat and facebook/llama-2-13b-chat from PaddleNLP, with temperature=0.6, decode_strategy="greedy_search", and max_new_tokens=512. The results:

|model|method|acc|
|----|----|----|
|llama-2-7b-chat|CoT|0|
|llama-2-7b-chat|standard sampling|0|
|llama-2-7b-chat|ToT|3%|
|llama-2-13b-chat|CoT|0|
|llama-2-13b-chat|standard sampling|0|
|llama-2-13b-chat|ToT|2%|

## How to add a new task

Setting up a new task is easy, and mainly involves two steps.
* Set up a new task class in ``tot/tasks/`` and a task file in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
* Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
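As an illustration of the first step, a new task class might look like the sketch below. This is hypothetical: the exact base-class interface lives in ``tot/tasks/``, and the methods shown here are inferred from how ``run.py`` uses ``get_task`` and ``task.test_output``.

```python
# Sketch of a hypothetical task class, mirroring how run.py calls
# get_task(...) and task.test_output(idx, output). Illustrative only;
# see tot/tasks/game24.py for the real interface.
class SortTask:
    """Toy task: sort a string of digits into ascending order."""

    def __init__(self):
        # A real task would load its inputs from a file in tot/data/.
        self.data = ["3 1 2", "9 7 8"]

    def __len__(self):
        return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]

    def test_output(self, idx: int, output: str) -> dict:
        # Score 1 if the model output is the correctly sorted sequence.
        expected = " ".join(sorted(self.data[idx].split()))
        return {"r": int(output.strip() == expected)}


task = SortTask()
print(task.test_output(0, "1 2 3"))  # {'r': 1}
```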

## Acknowledgements

We learn from the excellent framework design of Shunyu Yao et al., and we would like to express our thanks to the authors of Tree of Thoughts and their open source community.

```bibtex
@misc{yao2023tree,
      title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
      author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
      year={2023},
      eprint={2305.10601},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from src.llm import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
from src.tot.methods.bfs import solve
from src.tot.tasks.game24 import Game24Task

args = argparse.Namespace(
    backend="llama-2-7b-chat",
    temperature=0.6,
    task="game24",
    naive_run=False,
    prompt_sample=None,
    method_generate="propose",
    method_evaluate="value",
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
    log_fp="log.txt",
)

task = Game24Task()
if args.backend in llm_config.keys():
    chatter = llamaChatCompletion(args.backend)
elif args.backend in Ernie_llm_list:
    chatter = Ernie(model=args.backend)
else:
    raise ValueError(f"Unsupported backend: {args.backend}")
ys, infos = solve(args, task, 900, chatter=chatter)
print(ys[0])
print(infos)
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.1.0
frozenlist==1.3.3
idna==3.4
mpmath==1.3.0
multidict==6.0.4
numpy==1.24.3
requests==2.31.0
sympy==1.12
tqdm==4.65.0
urllib3==2.0.2
yarl==1.9.2
pandas==2.0.3
erniebot==0.5.0
paddlenlp==2.7.1
paddlepaddle-gpu==2.6.0
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# coding=utf8, ErnestinaQiu

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time

from src.llm.llama import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
from src.tot.methods.bfs import naive_solve, solve
from src.tot.models import gpt_usage
from src.tot.tasks import get_task


def run(args, chatter):
    task = get_task(args.task)
    logs, cnt_avg, cnt_any = [], 0, 0
    if args.naive_run:
        file = f"./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_select}_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
    else:
        file = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
    os.makedirs(os.path.dirname(file), exist_ok=True)

    for i in range(args.task_start_index, args.task_end_index):
        args.log_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.log"
        args.query_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_query.log"
        with open(args.log_fp, "a", encoding="utf8") as f:
            f.write(f"------ index: {i}")

        with open(args.query_fp, "a", encoding="utf8") as f:
            f.write(f"------ index: {i}")

        chatter.query = []
        chatter.tokenizer.init_chat_template(
            os.path.join(os.getcwd(), "pipelines", "examples", "tree-of-thought", "src", "llm", "chat_template.json")
        )

        # solve
        if args.naive_run:
            ys, info = naive_solve(args, task, i, chatter=chatter)
        else:
            ys, info = solve(args, task, i, chatter=chatter)

        # log
        infos = [task.test_output(i, y) for y in ys]
        info.update({"idx": i, "ys": ys, "infos": infos, "usage_so_far": gpt_usage(args.backend)})
        logs.append(info)
        with open(file, "w") as f:
            json.dump(logs, f, indent=4)

        # log main metric
        accs = [info["r"] for info in infos]
        cnt_avg += sum(accs) / len(accs)
        cnt_any += any(accs)
        mes = f"{i}, 'sum(accs)', {sum(accs)}, 'cnt_avg', {cnt_avg}, 'cnt_any', {cnt_any}, '\n'"
        with open(metric_fp, "a", encoding="utf8") as f:
            f.write(mes)

        with open(args.query_fp, "a", encoding="utf8") as f:
            f.write(json.dumps(chatter.query))

    n = args.task_end_index - args.task_start_index
    mes2 = f"cnt_avg / n: {cnt_avg / n}, cnt_any / n: {cnt_any / n}"
    mes3 = f"'usage_so_far', {gpt_usage(args.backend)}"
    with open(metric_fp, "a", encoding="utf8") as f:
        f.write(mes2)
        f.write(mes3)


llm_backend_choices = list(llm_config.keys())


def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument("--backend", type=str, choices=llm_backend_choices, default="llama-2-7b-chat")
    args.add_argument("--temperature", type=float, default=0.6)

    args.add_argument("--task", type=str, required=True, choices=["game24", "text", "crosswords"])
    args.add_argument("--task_start_index", type=int, default=900)
    args.add_argument("--task_end_index", type=int, default=1000)

    args.add_argument("--naive_run", action="store_true")
    args.add_argument(
        "--prompt_sample", type=str, choices=["standard", "cot"]
    )  # only used when method_generate = sample, or naive_run

    args.add_argument("--method_generate", type=str, choices=["sample", "propose"])
    args.add_argument("--method_evaluate", type=str, choices=["value", "vote"])
    args.add_argument("--method_select", type=str, choices=["sample", "greedy"], default="greedy")
    args.add_argument("--n_generate_sample", type=int, default=1)  # only thing needed if naive_run
    args.add_argument("--n_evaluate_sample", type=int, default=1)
    args.add_argument("--n_select_sample", type=int, default=1)

    args.add_argument("--query_fp", type=str, default=f"./logs/default/query_{int(time.time())}.log")

    args = args.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    if args.backend in llm_backend_choices:
        chatter = llamaChatCompletion(args.backend)
    elif args.backend in Ernie_llm_list:
        chatter = Ernie(model=args.backend)
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    run(args, chatter=chatter)
