Commit 63b25fb

Merge pull request #152 from CerebrasResearch/cepo_rc
CePO
2 parents a1b5840 + a2c301b

7 files changed: +558 −36 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -164,3 +164,6 @@ cython_debug/
 # Ignore Mac DS_Store files
 .DS_Store
 **/.DS_Store
+
+# VS Code
+.vscode/
```

README.md

Lines changed: 58 additions & 32 deletions
```diff
@@ -212,22 +212,23 @@ response = client.chat.completions.create(
 
 ## Implemented techniques
 
-| Approach | Slug | Description |
-| --- | --- | --- |
-| CoT with Reflection | `cot_reflection` | Implements chain-of-thought reasoning with \<thinking\>, \<reflection\> and \<output\> sections |
-| PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
-| ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
-| Self-Consistency | `self_consistency` | Implements an advanced self-consistency method |
-| Z3 Solver | `z3` | Utilizes the Z3 theorem prover for logical reasoning |
-| R* Algorithm | `rstar` | Implements the R* algorithm for problem-solving |
-| LEAP | `leap` | Learns task-specific principles from few-shot examples |
-| Round Trip Optimization | `rto` | Optimizes responses through a round-trip process |
-| Best of N Sampling | `bon` | Generates multiple responses and selects the best one |
-| Mixture of Agents | `moa` | Combines responses from multiple critiques |
-| Monte Carlo Tree Search | `mcts` | Uses MCTS for decision-making in chat responses |
-| PV Game | `pvg` | Applies a prover-verifier game approach at inference time |
-| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
-| Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation |
+| Approach | Slug | Description |
+| --- | --- | --- |
+| Cerebras Planning and Optimization | `cepo` | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques |
+| CoT with Reflection | `cot_reflection` | Implements chain-of-thought reasoning with \<thinking\>, \<reflection\> and \<output\> sections |
+| PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
+| ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
+| Self-Consistency | `self_consistency` | Implements an advanced self-consistency method |
+| Z3 Solver | `z3` | Utilizes the Z3 theorem prover for logical reasoning |
+| R* Algorithm | `rstar` | Implements the R* algorithm for problem-solving |
+| LEAP | `leap` | Learns task-specific principles from few-shot examples |
+| Round Trip Optimization | `rto` | Optimizes responses through a round-trip process |
+| Best of N Sampling | `bon` | Generates multiple responses and selects the best one |
+| Mixture of Agents | `moa` | Combines responses from multiple critiques |
+| Monte Carlo Tree Search | `mcts` | Uses MCTS for decision-making in chat responses |
+| PV Game | `pvg` | Applies a prover-verifier game approach at inference time |
+| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
+| Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation |
 
 ## Implemented plugins
```
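Clients pick one of the slugs above per request; by optillm's convention the slug is prefixed to the model name in the chat completion call. A minimal sketch of building that identifier (`prefixed_model` is a hypothetical helper for illustration; no server is contacted here):

```python
def prefixed_model(slug: str, model: str) -> str:
    """Compose the '<slug>-<model>' identifier optillm uses to route a request."""
    return f"{slug}-{model}"

# e.g. pass this as `model=` to an OpenAI-compatible client pointed at the proxy
print(prefixed_model("cepo", "llama-3.3-70b"))  # -> cepo-llama-3.3-70b
```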
```diff
@@ -244,22 +245,38 @@ response = client.chat.completions.create(
 
 optillm supports various command-line arguments and environment variables for configuration.
 
-| Parameter | Description | Default Value |
-| --- | --- | --- |
-| `--approach` | Inference approach to use | `"auto"` |
-| `--simulations` | Number of MCTS simulations | 2 |
-| `--exploration` | Exploration weight for MCTS | 0.2 |
-| `--depth` | Simulation depth for MCTS | 1 |
-| `--best-of-n` | Number of samples for best_of_n approach | 3 |
-| `--model` | OpenAI model to use | `"gpt-4o-mini"` |
-| `--base-url` | Base URL for OpenAI compatible endpoint | `""` |
-| `--rstar-max-depth` | Maximum depth for rStar algorithm | 3 |
-| `--rstar-num-rollouts` | Number of rollouts for rStar algorithm | 5 |
-| `--rstar-c` | Exploration constant for rStar algorithm | 1.4 |
-| `--n` | Number of final responses to be returned | 1 |
-| `--return-full-response` | Return the full response including the CoT with <thinking> tags | `False` |
-| `--port` | Specify the port to run the proxy | 8000 |
-| `--optillm-api-key` | Optional API key for client authentication to optillm | `""` |
+| Parameter | Description | Default Value |
+| --- | --- | --- |
+| `--approach` | Inference approach to use | `"auto"` |
+| `--simulations` | Number of MCTS simulations | 2 |
+| `--exploration` | Exploration weight for MCTS | 0.2 |
+| `--depth` | Simulation depth for MCTS | 1 |
+| `--best-of-n` | Number of samples for best_of_n approach | 3 |
+| `--model` | OpenAI model to use | `"gpt-4o-mini"` |
+| `--base-url` | Base URL for OpenAI compatible endpoint | `""` |
+| `--rstar-max-depth` | Maximum depth for rStar algorithm | 3 |
+| `--rstar-num-rollouts` | Number of rollouts for rStar algorithm | 5 |
+| `--rstar-c` | Exploration constant for rStar algorithm | 1.4 |
+| `--n` | Number of final responses to be returned | 1 |
+| `--return-full-response` | Return the full response including the CoT with <thinking> tags | `False` |
+| `--port` | Specify the port to run the proxy | 8000 |
+| `--optillm-api-key` | Optional API key for client authentication to optillm | `""` |
+| `--cepo_bestofn_n` | Number of responses to be generated in best of n stage | 3 |
+| `--cepo_bestofn_temperature` | Temperature for verifier in best of n stage | 0.1 |
+| `--cepo_bestofn_max_tokens` | Maximum number of tokens for verifier in best of n stage | 4096 |
+| `--cepo_bestofn_rating_type` | Type of rating in best of n stage ("absolute" or "pairwise") | `"absolute"` |
+| `--cepo_planning_n` | Number of plans generated in planning stage | 3 |
+| `--cepo_planning_m` | Number of attempts to generate n plans in planning stage | 6 |
+| `--cepo_planning_temperature_step1` | Temperature for generator in step 1 of planning stage | 0.55 |
+| `--cepo_planning_temperature_step2` | Temperature for generator in step 2 of planning stage | 0.25 |
+| `--cepo_planning_temperature_step3` | Temperature for generator in step 3 of planning stage | 0.1 |
+| `--cepo_planning_temperature_step4` | Temperature for generator in step 4 of planning stage | 0 |
+| `--cepo_planning_max_tokens_step1` | Maximum number of tokens in step 1 of planning stage | 4096 |
+| `--cepo_planning_max_tokens_step2` | Maximum number of tokens in step 2 of planning stage | 4096 |
+| `--cepo_planning_max_tokens_step3` | Maximum number of tokens in step 3 of planning stage | 4096 |
+| `--cepo_planning_max_tokens_step4` | Maximum number of tokens in step 4 of planning stage | 4096 |
+| `--cepo_print_output` | Whether to print the output of each stage | `False` |
+| `--cepo_config_file` | Path to CePO configuration file | None |
 
 When using Docker, these can be set as environment variables prefixed with `OPTILLM_`.
```
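For the Docker case, a minimal sketch of the assumed mapping (each flag name uppercased and prefixed with `OPTILLM_`; variable names here follow that convention and are illustrative):

```shell
# Assumed mapping: --approach -> OPTILLM_APPROACH, --port -> OPTILLM_PORT, etc.
export OPTILLM_APPROACH=cepo
export OPTILLM_PORT=8000
echo "$OPTILLM_APPROACH on port $OPTILLM_PORT"
```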

```diff
@@ -308,6 +325,15 @@ Authorization: Bearer your_secret_api_key
 
 ## SOTA results on benchmarks with optillm
 
+### CePO on math and code benchmarks
+
+| Method | Math-L5 | MMLU-Pro (Math) | GPQA | CRUX | LiveCodeBench (pass@1) | Simple QA |
+| -------------------------: | :-----: | :-------------: | :--: | :--: | :--------------------: | :-------: |
+| Llama 3.1 70B | 41.6 | 72.9 | 41.7 | 64.2 | 24.5 | 14.7 |
+| Llama 3.3 70B | 51.0 | 78.6 | 49.1 | 72.6 | 27.1 | 20.9 |
+| Llama 3.1 405B | 49.8 | 79.2 | 50.7 | 73.0 | 31.8 | 13.5 |
+| CePO (using Llama 3.3 70B) | 69.6 | 84.8 | 55.5 | 80.1 | 31.9 | 22.6 |
+
 ### coc-claude-3-5-sonnet-20241022 on AIME 2024 pass@1 (Nov 2024)
 
 | Model | Score |
```

optillm.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -3,6 +3,7 @@
 import os
 import secrets
 from flask import Flask, request, jsonify
+from cerebras.cloud.sdk import Cerebras
 from openai import AzureOpenAI, OpenAI
 from flask import Response
 import json
@@ -13,6 +14,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple, Optional, Union, Dict, Any, List
 from importlib.metadata import version
+from dataclasses import fields
 
 # Import approach modules
 from optillm.mcts import chat_with_mcts
@@ -27,6 +29,7 @@
 from optillm.plansearch import plansearch
 from optillm.leap import leap
 from optillm.reread import re2_approach
+from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -50,7 +53,14 @@ def get_config():
     from optillm.inference import create_inference_client
     API_KEY = os.environ.get("OPTILLM_API_KEY")
     default_client = create_inference_client()
-# OpenAI, Azure, or LiteLLM API configuration
+# Cerebras, OpenAI, Azure, or LiteLLM API configuration
+elif os.environ.get("CEREBRAS_API_KEY"):
+    API_KEY = os.environ.get("CEREBRAS_API_KEY")
+    base_url = server_config['base_url']
+    if base_url != "":
+        default_client = Cerebras(api_key=API_KEY, base_url=base_url)
+    else:
+        default_client = Cerebras(api_key=API_KEY)
 elif os.environ.get("OPENAI_API_KEY"):
     API_KEY = os.environ.get("OPENAI_API_KEY")
     base_url = server_config['base_url']
@@ -104,7 +114,7 @@ def get_config():
 
 # List of known approaches
 known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
-                    "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
+                    "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2", "cepo"]
 
 plugin_approaches = {}
 
@@ -124,7 +134,7 @@ def none_approach(
         model: Model identifier
         original_messages: Original messages from the request
         **kwargs: Additional parameters to pass through
-    
+
     Returns:
         Dict[str, Any]: Full OpenAI API response
     """
@@ -282,6 +292,8 @@ def execute_single_approach(approach, system_prompt, initial_query, client, model):
         return leap(system_prompt, initial_query, client, model)
     elif approach == 're2':
         return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
+    elif approach == 'cepo':
+        return cepo(system_prompt, initial_query, client, model, cepo_config)
     elif approach in plugin_approaches:
         return plugin_approaches[approach](system_prompt, initial_query, client, model)
     else:
@@ -690,6 +702,12 @@ def parse_args():
     parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
                         help="Base url for OpenAI compatible endpoint")
 
+    # Special handling of all the CePO configuration options
+    for field in fields(CepoConfig):
+        parser.add_argument(f"--cepo_{field.name}", dest=f"cepo_{field.name}", type=field.type, default=None, help=f"CePO configuration for {field.name}")
+
+    parser.add_argument("--cepo_config_file", dest="cepo_config_file", type=str, default="./optillm/cepo/configs/cepo_config.yaml", help="Path to CePO configuration file")
+
     args = parser.parse_args()
 
     # Convert argument names to match server_config keys
@@ -703,6 +721,7 @@ def parse_args():
 
 def main():
     global server_config
+    global cepo_config
     # Call this function at the start of main()
     args = parse_args()
     # Update server_config with all argument values
@@ -717,6 +736,11 @@ def main():
     if logging_level in logging_levels.keys():
         logger.setLevel(logging_levels[logging_level])
 
+    # Set and log the CePO config
+    cepo_config = init_cepo_config(server_config)
+    if args.approach == 'cepo':
+        logger.info(f"CePO config: {cepo_config}")
+
     logger.info(f"Starting server with approach: {server_config['approach']}")
     server_config_clean = server_config.copy()
     if server_config_clean['optillm_api_key']:
```
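The `parse_args` change in this diff registers one `--cepo_<field>` flag per `CepoConfig` field automatically via `dataclasses.fields`. A self-contained sketch of that pattern (`CepoConfigSketch` is a hypothetical three-field stand-in for the real dataclass):

```python
import argparse
from dataclasses import dataclass, fields

@dataclass
class CepoConfigSketch:
    """Hypothetical subset of CepoConfig, for illustration only."""
    bestofn_n: int = 3
    bestofn_temperature: float = 0.1
    planning_n: int = 3

parser = argparse.ArgumentParser()
# One --cepo_<field> flag per dataclass field; field.type drives argparse's coercion.
for field in fields(CepoConfigSketch):
    parser.add_argument(f"--cepo_{field.name}", dest=f"cepo_{field.name}",
                        type=field.type, default=None,
                        help=f"CePO configuration for {field.name}")

args = parser.parse_args(["--cepo_bestofn_n", "5", "--cepo_bestofn_temperature", "0.2"])
print(args.cepo_bestofn_n, args.cepo_bestofn_temperature, args.cepo_planning_n)
# -> 5 0.2 None
```

A `None` default lets downstream code distinguish "flag not given" from an explicit value when merging with the YAML config file.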

optillm/cepo/README.md

Lines changed: 44 additions & 0 deletions
# The Cerebras Planning and Optimization (CePO) Method

CePO is an inference-time computation method designed to enhance the accuracy of large language models (LLMs) on tasks requiring reasoning and planning, such as solving math or coding problems. It integrates several advanced techniques, including Best of N, Chain of Thought (CoT), Self-Reflection, Self-Improvement, and Prompt Engineering.

If you have any questions or want to contribute, please reach out to us on [cerebras.ai/discord](https://cerebras.ai/discord).

## CePO Methodology

In CePO, the Best of N technique is applied to `bestofn_n` solution candidates. Each candidate is generated through the following four steps:

**Step 1**: Plan Generation
The model generates a detailed, step-by-step plan to solve the problem, along with its confidence level for each step.

**Step 2**: Initial Solution
Using the plan from Step 1, the model produces an initial solution.

Steps 1 and 2 are repeated `planning_n` times to generate multiple solution proposals. If the model exceeds the token budget during Step 1 or 2, the plan or solution is marked as incomplete, rejected, and regenerated. A maximum of `planning_m` attempts is made to generate `planning_n` valid proposals.

**Step 3**: Plan Refinement
The model reviews all generated solution proposals and their associated plans, identifies inconsistencies, and constructs a refined, final step-by-step plan.

**Step 4**: Final Solution
The model uses the refined plan from Step 3 to produce the final answer.
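The four steps can be condensed into a sketch like the following. This is an illustrative outline of the loop described above, not the actual optillm implementation; `generate` and `rate` are stand-ins for the model call and the Best of N verifier:

```python
def cepo_sketch(task, generate, rate, bestofn_n=3, planning_n=3, planning_m=6):
    """Illustrative sketch of the four-step CePO loop; not the real code."""
    candidates = []
    for _ in range(bestofn_n):
        # Steps 1-2: draft plan+solution proposals, allowing up to
        # planning_m attempts to collect planning_n valid proposals.
        proposals, attempts = [], 0
        while len(proposals) < planning_n and attempts < planning_m:
            attempts += 1
            plan = generate(f"Step-by-step plan with confidence: {task}")
            solution = generate(f"Initial solution using plan: {plan}")
            if plan and solution:  # incomplete (over-budget) outputs are rejected
                proposals.append((plan, solution))
        # Step 3: review all proposals, resolve inconsistencies, refine the plan.
        refined = generate(f"Refined final plan from proposals: {proposals}")
        # Step 4: produce this candidate's final answer from the refined plan.
        candidates.append(generate(f"Final answer using plan: {refined}"))
    # Best of N: the verifier rates the bestofn_n candidates; highest rating wins.
    return max(candidates, key=rate)
```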
## CePO Current Status
27+
28+
This project is a work in progress, and the provided code is in an early experimental stage. While the proposed approach works well across the benchmarks we tested, further improvements can be achieved by task-specific customizations to prompts.
29+
30+
## CePO Ablation studies
31+
32+
We conducted ablation studies to evaluate the impact of various hyperparameters in the CePO framework. Our results indicate that the chosen hyperparameter settings strike a good balance between computational cost and accuracy.
33+
34+
Interestingly, the self-critique and quality improvement capabilities of existing off-the-shelf models do not always scale proportionally with increased inference compute. Addressing this limitation remains a key focus, and we plan to explore custom model fine-tuning as a potential solution in the future.
35+
36+
| bestofn_n | planning_n | planning_m | bestofn_rating_type | Math-L5 | MMLU-Pro (Math) | GPQA | CRUX | Comments |
37+
| :-------: | :--------: | :--------: | :-----------------: | :-----: | :-------------: | :---: | :---: | :------------- |
38+
| 3 | 3 | 6 | absolute | 69.6 | 84.8 | 55.5 | 80.1 | Default config |
39+
| 3 | 3 | 6 | pairwise | 67.7 | 83.5 | 55.6 | 79.8 | |
40+
| 3 | 2 | 5 | absolute | 67.1 | 85.1 | 55.1 | 79.0 | |
41+
| 3 | 5 | 8 | absolute | 69.4 | 84.3 | 55.6 | 81.1 | |
42+
| 5 | 3 | 6 | absolute | 68.7 | 85.4 | 54.8 | 79.9 | |
43+
| 7 | 3 | 6 | absolute | 69.6 | 82.8 | 54.7 | 78.4 | |
44+
| 9 | 3 | 6 | absolute | 68.9 | 83.4 | 55.7 | 80.6 | |
