
Commit 89a9a60

[MCTS] Add self-refined MCTS (#6098)
* add reasoner
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update code
* delete llama
* update prompts
* update readme
* update readme

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4294ae8 commit 89a9a60

5 files changed: +324, -3 lines changed


applications/ColossalChat/README.md

Lines changed: 18 additions & 3 deletions
@@ -27,11 +27,11 @@
 - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
 - [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
 - [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [O1 Journey](#o1-journey)
+  - [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
 - [FAQ](#faq)
   - [How to save/load checkpoint](#faq)
   - [How to train with limited resources](#faq)
-- [The Plan](#the-plan)
-  - [Real-time progress](#real-time-progress)
 - [Invitation to open-source contribution](#invitation-to-open-source-contribution)
 - [Quick Preview](#quick-preview)
 - [Authors](#authors)
@@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
 ## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
 We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), which is an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.

-### Inference Quantization and Serving - After Training
+## Inference Quantization and Serving - After Training

 We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.

@@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
 Online inference server scripts can help you deploy your own services.
 For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).

+## O1 Journey
+### Inference with Self-refined MCTS
+We provide an implementation of the MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
+To run inference with MCTS, simply use the following script.
+```python
+from coati.reasoner.guided_search.mcts import MCTS
+from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
+
+problem = "How Many R in 'Strawberry'"
+
+search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
+answer = search_tree.simulate()
+print(answer)
+```
+
 ## Coati7B examples

 ### Generation
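
Note that the snippet above talks to the model through an OpenAI-compatible client, so it assumes a server for the configured model is already reachable at the `base_url` set in `Qwen32B_prompt_CFG` (by default `http://0.0.0.0:8008/v1`, defined in `prompt_store/qwen.py` below).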

applications/ColossalChat/coati/reasoner/guided_search/llm.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+
+API_KEY = "Dummy API Key"
+
+
+def get_client(base_url: str | None = None) -> openai.Client:
+    return openai.Client(api_key=API_KEY, base_url=base_url)
+
+
+def chat_completion(
+    messages: list[ChatCompletionMessageParam],
+    model: str,
+    base_url: str | None = None,
+    temperature: float = 0.8,
+    **kwargs,
+) -> ChatCompletion:
+    client = get_client(base_url)
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        **kwargs,
+    )
+    return response
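
For reference, `chat_completion` is a thin wrapper around the OpenAI chat completions API and can be exercised directly against any OpenAI-compatible endpoint. A minimal sketch follows; the endpoint URL, model name, and prompts are placeholders rather than part of the commit.

```python
from coati.reasoner.guided_search.llm import chat_completion

# Placeholder endpoint and model name: replace with whatever your server exposes.
response = chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "How many 'r's are in 'strawberry'?"},
    ],
    model="Qwen2.5-32B-Instruct",
    base_url="http://0.0.0.0:8008/v1",
    max_tokens=512,  # extra keyword arguments are passed through to the OpenAI client
)
print(response.choices[0].message.content)
```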

applications/ColossalChat/coati/reasoner/guided_search/mcts.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
+"""
+Implementation of MCTS + Self-refine algorithm.
+
+Reference:
+1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
+Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
+2. https://github.com/BrendanGraham14/mcts-llm/
+3. https://github.com/trotsky1997/MathBlackBox/
+4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
+"""
+
+from __future__ import annotations
+
+import math
+from collections import deque
+
+import numpy as np
+import tqdm
+from coati.reasoner.guided_search.llm import chat_completion
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+from pydantic import BaseModel
+
+
+class MCTSNode(BaseModel):
+    """
+    Node for MCTS.
+    """
+
+    answer: str
+    parent: MCTSNode = None
+    children: list[MCTSNode] = []
+    num_visits: int = 0
+    Q: int = 0
+    rewards: list[int] = []
+
+    def expand_node(self, node) -> None:
+        self.children.append(node)
+
+    def add_reward(self, reward: int) -> None:
+        self.rewards.append(reward)
+        self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
+
+
+class MCTS(BaseModel):
+    """
+    Simulation of MCTS process.
+    """
+
+    problem: str
+    max_simulations: int
+    cfg: PromptCFG
+    C: float = 1.4
+    max_children: int = 2
+    epsilon: float = 1e-5
+    root: MCTSNode = None
+
+    def initialization(self):
+        """
+        Root Initiation.
+        """
+        # Dummy answer as root.
+        base_answer = self.sample_base_answer()
+        self.root = MCTSNode(answer=base_answer)
+        self.self_evaluate(self.root)
+
+    def is_fully_expanded(self, node: MCTSNode):
+        return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
+
+    def select_node(self) -> MCTSNode:
+        """
+        Select next node to explore.
+        """
+        candidates: list[MCTSNode] = []
+        to_explore = deque([self.root])
+
+        while to_explore:
+            current_node = to_explore.popleft()
+            if not self.is_fully_expanded(current_node):
+                candidates.append(current_node)
+            to_explore.extend(current_node.children)
+
+        if not candidates:
+            return self.root
+
+        return max(candidates, key=self.compute_uct)
+
+    def self_evaluate(self, node: MCTSNode):
+        """
+        Sample reward of the answer.
+        """
+        reward = self.sample_reward(node)
+        node.add_reward(reward)
+
+    def back_propagation(self, node: MCTSNode):
+        """
+        Back propagate the value of the refined answer.
+        """
+        parent = node.parent
+        while parent:
+            best_child_Q = max(child.Q for child in parent.children)
+            parent.Q = (parent.Q + best_child_Q) / 2
+            parent.num_visits += 1
+            parent = parent.parent
+
+    def compute_uct(self, node: MCTSNode):
+        """
+        Compute UCT.
+        """
+        if node.parent is None:
+            return -100
+        return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
+
+    def simulate(self):
+        self.initialization()
+        for _ in tqdm.tqdm(range(self.max_simulations)):
+            node = self.select_node()
+            child = self.self_refine(node)
+            node.expand_node(child)
+            self.self_evaluate(child)
+            self.back_propagation(child)
+
+        return self.get_best_answer()
+
+    def get_best_answer(self):
+        to_visit = deque([self.root])
+        best_node = self.root
+
+        while to_visit:
+            current_node = to_visit.popleft()
+            if current_node.Q > best_node.Q:
+                best_node = current_node
+            to_visit.extend(current_node.children)
+
+        return best_node.answer
+
+    def self_refine(self, node: MCTSNode):
+        """
+        Refine node.
+        """
+        critique_response = chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": self.cfg.critic_system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": "\n\n".join(
+                        [
+                            f"<problem>\n{self.problem}\n</problem>",
+                            f"<current_answer>\n{node.answer}\n</current_answer>",
+                        ]
+                    ),
+                },
+            ],
+            model=self.cfg.model,
+            base_url=self.cfg.base_url,
+            max_tokens=self.cfg.max_tokens,
+        )
+        critique = critique_response.choices[0].message.content
+        assert critique is not None
+        refined_answer_response = chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": self.cfg.refine_system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": "\n\n".join(
+                        [
+                            f"<problem>\n{self.problem}\n</problem>",
+                            f"<current_answer>\n{node.answer}\n</current_answer>",
+                            f"<critique>\n{critique}\n</critique>",
+                        ]
+                    ),
+                },
+            ],
+            model=self.cfg.model,
+            base_url=self.cfg.base_url,
+            max_tokens=self.cfg.max_tokens,
+        )
+        refined_answer = refined_answer_response.choices[0].message.content
+        assert refined_answer is not None
+
+        return MCTSNode(answer=refined_answer, parent=node)
+
+    def sample_base_answer(self):
+        response = chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
+                },
+                {
+                    "role": "user",
+                    "content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
+                },
+            ],
+            model=self.cfg.model,
+            base_url=self.cfg.base_url,
+            max_tokens=self.cfg.max_tokens,
+        )
+        assert response.choices[0].message.content is not None
+        return response.choices[0].message.content
+
+    def sample_reward(self, node: MCTSNode):
+        """
+        Calculate reward.
+        """
+        messages = [
+            {
+                "role": "system",
+                "content": self.cfg.evaluate_system_prompt,
+            },
+            {
+                "role": "user",
+                "content": "\n\n".join(
+                    [
+                        f"<problem>\n{self.problem}\n</problem>",
+                        f"<answer>\n{node.answer}\n</answer>",
+                    ]
+                ),
+            },
+        ]
+        for attempt in range(3):
+            try:
+                response = chat_completion(
+                    messages=messages,
+                    model=self.cfg.model,
+                    base_url=self.cfg.base_url,
+                    max_tokens=self.cfg.max_tokens,
+                )
+                assert response.choices[0].message.content is not None
+                return int(response.choices[0].message.content)
+            except ValueError:
+                messages.extend(
+                    [
+                        {
+                            "role": "assistant",
+                            "content": response.choices[0].message.content,
+                        },
+                        {
+                            "role": "user",
+                            "content": "Failed to parse reward as an integer.",
+                        },
+                    ]
+                )
+                if attempt == 2:
+                    raise
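
For readers following the search logic, the node value and selection rules implemented above can be summarized as follows; this is a direct transcription of `add_reward`, `back_propagation`, and `compute_uct` into formulas.

```latex
% Node value after collecting rewards r_1, ..., r_k (MCTSNode.add_reward):
Q(v) = \frac{\min_i r_i + \frac{1}{k}\sum_{i=1}^{k} r_i}{2}

% Parent update during back-propagation (MCTS.back_propagation):
Q(p) \leftarrow \frac{Q(p) + \max_{c \in \mathrm{children}(p)} Q(c)}{2},
\qquad N(p) \leftarrow N(p) + 1

% Selection score (MCTS.compute_uct), with C = 1.4 and \epsilon = 10^{-5}:
\mathrm{UCT}(v) = Q(v) + C \sqrt{\frac{\ln\left(N(\mathrm{parent}(v)) + 1\right)}{N(v) + \epsilon}}
```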

applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class PromptCFG(BaseModel):
+    model: str
+    base_url: str
+    max_tokens: int = 4096
+    critic_system_prompt: str
+    refine_system_prompt: str
+    evaluate_system_prompt: str
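
`PromptCFG` is the switch point for swapping backends: `model` and `base_url` are forwarded to every `chat_completion` call, while `critic_system_prompt`, `refine_system_prompt`, and `evaluate_system_prompt` are the system prompts used by `MCTS.self_refine` (critique and refinement) and `MCTS.sample_reward` (integer reward scoring), respectively.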

applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+"""
+Prompts for Qwen Series.
+"""
+
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+
+Qwen32B_prompt_CFG = PromptCFG(
+    base_url="http://0.0.0.0:8008/v1",
+    model="Qwen2.5-32B-Instruct",
+    critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
+    "Highlight specific areas that need refinement or correction.",
+    refine_system_prompt="""# Instruction
+Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
+""",
+    evaluate_system_prompt=(
+        "Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
+        "Do not give a full score above 95. Make sure the reward score is an integer. "
+        "Return *ONLY* the score."
+    ),
+)
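
To point the search at a different model or server, a config can reuse these prompts while overriding the connection details. A minimal sketch, assuming a hypothetical local endpoint and model name (neither is part of this commit):

```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG

# Hypothetical values: reuse the Qwen prompts, but target another OpenAI-compatible server.
my_cfg = PromptCFG(
    model="my-local-model",               # placeholder model name
    base_url="http://127.0.0.1:8000/v1",  # placeholder endpoint
    max_tokens=2048,
    critic_system_prompt=Qwen32B_prompt_CFG.critic_system_prompt,
    refine_system_prompt=Qwen32B_prompt_CFG.refine_system_prompt,
    evaluate_system_prompt=Qwen32B_prompt_CFG.evaluate_system_prompt,
)

answer = MCTS(problem="How many 'r's are in 'strawberry'?", max_simulations=4, cfg=my_cfg).simulate()
print(answer)
```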
