Commit 7e02af3

Add request_config support to cot_reflection
The cot_reflection function now accepts a request_config parameter, allowing dynamic configuration of temperature and max_tokens for API calls. Defaults are set to temperature=0.6 and max_tokens=4096 if not provided.
1 parent: 50f5f7a
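
In practice the new parameter is a plain dict. A minimal sketch of calling the function directly (the client setup, model name, and prompts are illustrative, not taken from this commit):

from openai import OpenAI
from optillm.cot_reflection import cot_reflection

client = OpenAI()  # assumes OPENAI_API_KEY is set; any OpenAI-compatible client works

# Keys omitted from request_config fall back to the new defaults
# (temperature=0.6, max_tokens=4096); request_config=None uses both defaults.
result = cot_reflection(
    system_prompt="You are a concise assistant.",
    initial_query="Summarize chain-of-thought prompting in one sentence.",
    client=client,
    model="gpt-4o-mini",
    request_config={"temperature": 0.2, "max_tokens": 512},
)
print(result)  # the reply text (full reasoning trace if return_full_response=True)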

File tree: 2 files changed (+13 −5 lines)

optillm.py (1 addition, 1 deletion)

@@ -331,7 +331,7 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
                           c=server_config['rstar_c'])
         return rstar.solve(initial_query)
     elif approach == "cot_reflection":
-        return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
+        return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config)
     elif approach == 'plansearch':
         return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
     elif approach == 'leap':

optillm/cot_reflection.py (12 additions, 4 deletions)

@@ -3,8 +3,16 @@
 
 logger = logging.getLogger(__name__)
 
-def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False):
+def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False, request_config: dict = None):
     cot_completion_tokens = 0
+
+    # Extract temperature and max_tokens from request_config with defaults
+    temperature = 0.6  # Default to 0.6 as requested
+    max_tokens = 4096  # Default to 4096 as requested
+
+    if request_config:
+        temperature = request_config.get('temperature', temperature)
+        max_tokens = request_config.get('max_tokens', max_tokens)
     cot_prompt = f"""
         {system_prompt}
 
@@ -32,15 +40,15 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
         </output>
     """
 
-    # Make the API call
+    # Make the API call using user-provided or default parameters
     response = client.chat.completions.create(
         model=model,
         messages=[
             {"role": "system", "content": cot_prompt},
             {"role": "user", "content": initial_query}
         ],
-        temperature=0.7,
-        max_tokens=4096
+        temperature=temperature,
+        max_tokens=max_tokens
    )
 
     # Extract the full response
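
Because optillm is an OpenAI-compatible proxy, the practical effect is that sampling parameters sent by the end client can now reach cot_reflection. A hedged sketch of such a request, assuming a locally running server (the base_url, port, and the approach-prefixed model name are assumptions, not part of this diff):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")  # local proxy assumed

response = client.chat.completions.create(
    model="cot_reflection-gpt-4o-mini",  # approach prefix selecting cot_reflection (assumed convention)
    messages=[{"role": "user", "content": "What is 12 * 13? Think it through."}],
    temperature=0.2,   # ends up in request_config, overriding the 0.6 default
    max_tokens=1024,   # overrides the 4096 default added in this commit
)
print(response.choices[0].message.content)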
