Skip to content

Commit 451f4bf

Browse files
authored
Merge pull request #102 from codelion/fix-args
Fix args
2 parents 572a0c5 + 90bef2e commit 451f4bf

File tree

3 files changed

+294
-3
lines changed

3 files changed

+294
-3
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ response = client.chat.completions.create(
218218
| Plugin | Slug | Description |
219219
| ----------------------- | ------------------ | ---------------------------------------------------------------------------------------------- |
220220
| Router | `router` | Uses the [optillm-bert-uncased](https://huggingface.co/codelion/optillm-bert-uncased) model to route requests to different approaches based on the user prompt |
221+
| Chain-of-Code | `coc` | Implements a chain-of-code approach that combines CoT with code execution and LLM-based code simulation |
221222
| Memory | `memory` | Implements a short term memory layer, enables you to use unbounded context length with any LLM |
222223
| Privacy | `privacy` | Anonymize PII data in request and deanonymize it back to original value in response |
223224
| Read URLs | `readurls` | Reads all URLs found in the request, fetches the content at the URL and adds it to the context |
@@ -290,6 +291,20 @@ Authorization: Bearer your_secret_api_key
290291
```
291292
## SOTA results on benchmarks with optillm
292293

294+
### coc-claude-3-5-sonnet-20241022 on AIME 2024 pass@1 (Nov 2024)
295+
296+
| Model | Score |
297+
|-------|-----:|
298+
| o1-mini | 56.67 |
299+
| coc-claude-3-5-sonnet-20241022 | 46.67 |
300+
| coc-gemini/gemini-exp-1121 | 46.67 |
301+
| o1-preview | 40.00 |
302+
| f1-preview | 40.00 |
303+
| gemini-exp-1114 | 36.67 |
304+
| claude-3-5-sonnet-20241022 | 20.00 |
305+
| gemini-1.5-pro-002 | 20.00 |
306+
| gemini-1.5-flash-002 | 16.67 |
307+
293308
### readurls&memory-gpt-4o-mini on Google FRAMES Benchmark (Oct 2024)
294309
| Model | Accuracy |
295310
| ----- | -------- |
@@ -324,6 +339,7 @@ called patchflows. We saw huge performance gains across all the supported patchf
324339

325340
## References
326341

342+
- [Chain of Code: Reasoning with a Language Model-Augmented Code Emulator](https://arxiv.org/abs/2312.04474) - [Implementation](https://github.com/codelion/optillm/blob/main/optillm/plugins/coc_plugin.py)
327343
- [Entropy Based Sampling and Parallel CoT Decoding](https://github.com/xjdr-alt/entropix) - [Implementation](https://github.com/codelion/optillm/blob/main/optillm/entropy_decoding.py)
328344
- [Fact, Fetch, and Reason: A Unified Evaluation of Retrieval-Augmented Generation](https://arxiv.org/abs/2409.12941) - [Evaluation script](https://github.com/codelion/optillm/blob/main/scripts/eval_frames_benchmark.py)
329345
- [Writing in the Margins: Better Inference Pattern for Long Context Retrieval](https://www.arxiv.org/abs/2408.14906) - [Inspired the implementation of the memory plugin](https://github.com/codelion/optillm/blob/main/optillm/plugins/memory_plugin.py)

optillm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,10 @@ def proxy():
395395
model = data.get('model', server_config['model'])
396396

397397
optillm_approach = data.get('optillm_approach', server_config['approach'])
398+
logger.debug(data)
399+
server_config['mcts_depth'] = data.get('mcts_depth', server_config['mcts_depth'])
400+
server_config['mcts_exploration' ] = data.get('mcts_exploration', server_config['mcts_exploration'])
401+
server_config['mcts_simulations'] = data.get('mcts_simulations', server_config['mcts_simulations'])
398402

399403
system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)
400404

@@ -522,7 +526,7 @@ def parse_args():
522526
# Define arguments and their corresponding environment variables
523527
args_env = [
524528
("--optillm-api-key", "OPTILLM_API_KEY", str, "", "Optional API key for client authentication to optillm"),
525-
("--approach", "OPTILLM_APPROACH", str, "auto", "Inference approach to use", known_approaches),
529+
("--approach", "OPTILLM_APPROACH", str, "auto", "Inference approach to use", known_approaches + list(plugin_approaches.keys())),
526530
("--mcts-simulations", "OPTILLM_SIMULATIONS", int, 2, "Number of MCTS simulations"),
527531
("--mcts-exploration", "OPTILLM_EXPLORATION", float, 0.2, "Exploration weight for MCTS"),
528532
("--mcts-depth", "OPTILLM_DEPTH", int, 1, "Simulation depth for MCTS"),
@@ -571,10 +575,10 @@ def parse_args():
571575

572576
def main():
573577
global server_config
574-
args = parse_args()
575-
576578
# Call this function at the start of main()
577579
load_plugins()
580+
args = parse_args()
581+
578582
# Update server_config with all argument values
579583
server_config.update(vars(args))
580584

optillm/plugins/coc_plugin.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
import re
2+
import logging
3+
from typing import Tuple, Dict, Any, List
4+
import ast
5+
import traceback
6+
import math
7+
import importlib
8+
import json
9+
10+
logger = logging.getLogger(__name__)
11+
12+
# Plugin identifier (matches the `coc` slug listed in the README plugin table).
SLUG = "coc"

# Upper bound on execution attempts made by run(): the code is executed at most
# this many times, with an LLM fix requested between consecutive attempts,
# before run() falls back to LLM-based simulation.
MAX_FIX_ATTEMPTS = 3

# Modules that sanitize_code() pre-imports for generated code. Only the keys
# are read there (one `import <key>` line per entry); the values are not used
# by any code visible in this file.
ALLOWED_MODULES = {
    'math': math,
    'numpy': 'numpy',  # String (not a module object): meant to be imported in the execution context
}
23+
24+
# Initial code generation prompt.
# run() appends this to the caller's system prompt; it asks the model for a
# single ```python block that leaves its result in a variable named 'answer'.
CHAIN_OF_CODE_PROMPT = '''
Write Python code to solve this problem. The code should:
1. Break down the problem into clear computational steps
2. Use standard Python features and math operations
3. Store the final result in a variable named 'answer'
4. Include error handling where appropriate
5. Be complete and executable

Format your response using:
```python
[Your complete Python program here]
```
'''

# Code fix prompt.
# Filled in by generate_fixed_code() via str.format with two placeholders:
#   {code}  - the code that failed to execute
#   {error} - the error text returned by execute_code()
CODE_FIX_PROMPT = '''
The following Python code failed to execute. Fix the code to make it work.
Original code:
```python
{code}
```

Error encountered:
{error}

Please provide a complete, fixed version of the code that:
1. Addresses the error message
2. Maintains the same logic and approach
3. Stores the final result in 'answer'
4. Is complete and executable

Return only the fixed code in a code block:
```python
[Your fixed code here]
```
'''

# Simulation prompt.
# Filled in by simulate_execution() via str.format with the same {code} and
# {error} placeholders; asks the model to return only the value that 'answer'
# would hold, with no explanation.
SIMULATION_PROMPT = '''
The following Python code could not be executed directly. Analyze the code and determine what the answer would be.
Pay special attention to:
1. The core computational logic, ignoring any visualization or display code
2. The key mathematical operations that determine the final answer
3. Any logic that affects the 'answer' variable

Code to analyze:
```python
{code}
```

Runtime error encountered:
{error}

Return ONLY the final value that would be in the 'answer' variable. Return just the value, no explanations.
'''
80+
81+
def extract_code_blocks(text: str) -> List[str]:
    """Return the body of every ```python fenced block in *text*.

    Blocks are returned in order of appearance with surrounding whitespace
    removed. An empty list is returned when no fenced block is present.
    Each extracted block is also written to the log at INFO level.
    """
    fence_re = re.compile(r'```python\s*(.*?)\s*```', re.DOTALL)
    blocks = list(map(str.strip, fence_re.findall(text)))

    logger.info(f"Extracted {len(blocks)} code blocks")
    for i, block in enumerate(blocks):
        logger.info(f"Code block {i+1}:\n{block}")

    return blocks
90+
91+
def sanitize_code(code: str) -> str:
    """Wrap generated code in a `safe_execute()` function with standard imports.

    The returned source:
      * imports every module named in ALLOWED_MODULES at top level,
      * drops any line that mentions matplotlib/plotting (see the marker list
        below; matching is a case-insensitive substring test),
      * places the remaining lines inside `safe_execute()`, which returns the
        local `answer` if the code defined one, else None,
      * assigns that return value to a module-level `answer`.

    Args:
        code: Python source produced by the model.

    Returns:
        The wrapped source text, ready to be passed to an executor.

    NOTE(review): no function visible in this file calls sanitize_code();
    run() passes the raw code straight to execute_code().
    """
    # Add standard imports
    import_block = "\n".join(f"import {mod}" for mod in ALLOWED_MODULES)

    # Remove visualization code: skip matplotlib-related imports and plotting
    # commands. NOTE(review): these are plain substring checks, so e.g. a line
    # containing "configure" is dropped because it contains "figure".
    blocked_markers = ['matplotlib', 'plt.', '.plot(', '.show(', 'figure', 'subplot']
    safe_lines = [
        line for line in code.split('\n')
        if not any(marker in line.lower() for marker in blocked_markers)
    ]

    # Indent the body BEFORE building the f-string. A backslash inside an
    # f-string replacement field is a SyntaxError on Python < 3.12, which would
    # prevent this whole module from importing. The first line receives its
    # indent from the template below; every following line gets the same
    # 4-space indent here so the function body stays consistently indented.
    indented_body = "\n    ".join(safe_lines)

    # Add safety wrapper
    wrapper = f"""
{import_block}

def safe_execute():
    import numpy as np # Always allow numpy
    {indented_body}
    return answer if 'answer' in locals() else None

result = safe_execute()
answer = result
"""
    return wrapper
121+
122+
def execute_code(code: str) -> Tuple[Any, str]:
    """Execute *code* and return the value of its `answer` variable.

    Args:
        code: Python source to run.

    Returns:
        A `(answer, error)` pair. On success `error` is None and `answer` is
        whatever the code bound to the global name `answer`. On failure
        `answer` is None and `error` is a human-readable message: either the
        exception raised by the code (prefixed with its type name) or a note
        that no `answer` variable was produced.

    SECURITY NOTE: this calls exec() on model-generated text inside the current
    process, with full builtins and no sandboxing, time limit or resource
    limit. Treat the input as untrusted; consider running it in an isolated
    subprocess/container if the deployment is exposed to arbitrary prompts.
    """
    logger.info("Attempting to execute code")
    logger.info(f"Code:\n{code}")

    try:
        # Fresh globals per run so one execution cannot see another's names.
        execution_env = {}

        # Execute the code as-is
        exec(code, execution_env)

        # Look for answer variable
        if 'answer' in execution_env:
            answer = execution_env['answer']
            logger.info(f"Execution successful. Answer: {answer}")
            return answer, None
        else:
            error = "Code executed but did not produce an answer variable"
            logger.warning(error)
            return None, error

    except (Exception, SystemExit) as e:
        # SystemExit is not an Exception subclass: without catching it here,
        # generated code calling exit()/sys.exit() would terminate the host
        # process. KeyboardInterrupt is deliberately left to propagate.
        # The type name is included because str(e) alone is empty for
        # exceptions raised without a message.
        error = f"{type(e).__name__}: {e}"
        logger.error(f"Execution failed: {error}")
        return None, error
148+
149+
def generate_fixed_code(original_code: str, error: str, client, model: str) -> Tuple[str, int]:
    """Ask the LLM for a corrected version of code that failed to run.

    Args:
        original_code: The source that failed.
        error: The error text reported for that source.
        client: Object exposing `chat.completions.create(...)`.
        model: Model name forwarded to the client.

    Returns:
        `(fixed_code, completion_tokens)`. `fixed_code` is the first
        ```python block in the reply, or None when the reply contains no
        such block; `completion_tokens` is taken from `response.usage`.
    """
    logger.info("Requesting code fix from LLM")
    logger.info(f"Original error: {error}")

    fix_instructions = CODE_FIX_PROMPT.format(code=original_code, error=error)
    conversation = [
        {"role": "system", "content": fix_instructions},
        {"role": "user", "content": "Fix the code to make it work."},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=0.2,
    )

    reply_text = response.choices[0].message.content
    candidates = extract_code_blocks(reply_text)

    # Guard clause: nothing usable came back.
    if not candidates:
        logger.warning("No code block found in LLM response")
        return None, response.usage.completion_tokens

    logger.info("Received fixed code from LLM")
    return candidates[0], response.usage.completion_tokens
173+
174+
def simulate_execution(code: str, error: str, client, model: str) -> Tuple[Any, int]:
    """Ask the LLM to predict the value `answer` would hold after running *code*.

    Args:
        code: The source that could not be executed.
        error: The last error text reported for that source.
        client: Object exposing `chat.completions.create(...)`.
        model: Model name forwarded to the client.

    Returns:
        `(answer, completion_tokens)`. The reply text is stripped and, when it
        parses as a Python literal, converted with `ast.literal_eval`;
        otherwise the stripped text itself is returned. `answer` is None when
        the reply could not be read at all (for example, no text content).
    """
    logger.info("Attempting code simulation with LLM")

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SIMULATION_PROMPT.format(
                code=code, error=error)},
            {"role": "user", "content": "Simulate this code and return the final answer value."}
        ],
        temperature=0.2
    )

    try:
        result = response.choices[0].message.content.strip()
        # Try to convert to appropriate type
        try:
            answer = ast.literal_eval(result)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Only the failures literal_eval is documented to raise on malformed
            # input. A bare `except:` here would also swallow KeyboardInterrupt
            # and SystemExit.
            answer = result
        logger.info(f"Simulation successful. Result: {answer}")
        return answer, response.usage.completion_tokens
    except Exception as e:
        logger.error(f"Failed to parse simulation result: {str(e)}")
        return None, response.usage.completion_tokens
200+
201+
def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
    """Plugin entry point: solve *initial_query* with the Chain of Code strategy.

    Steps:
      1. Ask the model for a Python program (system prompt + CHAIN_OF_CODE_PROMPT).
         If the reply has no ```python block, the reply text is returned as-is.
      2. Execute the first code block. On failure, ask the model for a fix and
         retry, executing at most MAX_FIX_ATTEMPTS times in total.
      3. If execution never succeeds (or no fix could be obtained), ask the
         model to simulate the last version of the code.

    Args:
        system_prompt: Caller-supplied system prompt.
        initial_query: The user's problem statement.
        client: Object exposing `chat.completions.create(...)`.
        model: Model name forwarded to the client.

    Returns:
        `(response_text, completion_tokens)`, where `completion_tokens` is the
        sum of completion tokens over every LLM call made here.
    """
    logger.info("Starting Chain of Code execution")
    logger.info(f"Query: {initial_query}")

    # Initial code generation
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt + "\n" + CHAIN_OF_CODE_PROMPT},
            {"role": "user", "content": initial_query},
        ],
        temperature=0.7,
    )
    total_tokens = response.usage.completion_tokens

    # Extract initial code
    code_blocks = extract_code_blocks(response.choices[0].message.content)
    if not code_blocks:
        logger.warning("No code blocks found in response")
        return response.choices[0].message.content, total_tokens

    current_code = code_blocks[0]
    last_error = None

    # Strategy 1: direct execution, with LLM repair between attempts
    for fix_attempts in range(1, MAX_FIX_ATTEMPTS + 1):
        logger.info(f"Execution attempt {fix_attempts}/{MAX_FIX_ATTEMPTS}")

        answer, error = execute_code(current_code)
        if error is None:
            logger.info(f"Successful execution on attempt {fix_attempts}")
            return str(answer), total_tokens

        last_error = error

        # Final attempt already used: skip the repair request and go simulate.
        if fix_attempts >= MAX_FIX_ATTEMPTS:
            logger.warning(f"Failed after {fix_attempts} fix attempts")
            break

        logger.info(f"Requesting code fix, attempt {fix_attempts}")
        fixed_code, fix_tokens = generate_fixed_code(current_code, error, client, model)
        total_tokens += fix_tokens

        if not fixed_code:
            logger.error("Failed to get fixed code from LLM")
            break
        current_code = fixed_code

    # Strategy 2: If all execution attempts failed, try simulation
    logger.info("All execution attempts failed, trying simulation")
    simulated_answer, sim_tokens = simulate_execution(current_code, last_error, client, model)
    total_tokens += sim_tokens

    if simulated_answer is None:
        # If we get here, everything failed
        logger.warning("All strategies failed")
        return f"Error: Could not solve problem after all attempts. Last error: {last_error}", total_tokens

    logger.info("Successfully got answer from simulation")
    return str(simulated_answer), total_tokens

0 commit comments

Comments
 (0)