Commit 2536579

add support for conversation logging

1 parent 7b304eb commit 2536579
19 files changed: +2171 -315 lines

optillm.py (85 additions, 24 deletions)

@@ -3,6 +3,7 @@
 import os
 import secrets
 import time
+from pathlib import Path
 from flask import Flask, request, jsonify
 from cerebras.cloud.sdk import Cerebras
 from openai import AzureOpenAI, OpenAI

@@ -32,6 +33,7 @@
 from optillm.reread import re2_approach
 from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
 from optillm.batching import RequestBatcher, BatchingError
+from optillm.conversation_logger import ConversationLogger
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@@ -51,6 +53,9 @@
 # Global request batcher (initialized in main() if batch mode enabled)
 request_batcher = None
 
+# Global conversation logger (initialized in main() if logging enabled)
+conversation_logger = None
+
 def get_config():
     API_KEY = None
     if os.environ.get("OPTILLM_API_KEY"):

@@ -196,17 +201,17 @@ def none_approach(
     client: Any,
     model: str,
     original_messages: List[Dict[str, str]],
+    request_id: str = None,
     **kwargs
 ) -> Dict[str, Any]:
     """
     Direct proxy approach that passes through all parameters to the underlying endpoint.
 
     Args:
-        system_prompt: System prompt text (unused)
-        initial_query: Initial query/conversation (unused)
         client: OpenAI client instance
         model: Model identifier
         original_messages: Original messages from the request
+        request_id: Optional request ID for conversation logging
         **kwargs: Additional parameters to pass through
 
     Returns:

@@ -220,6 +225,13 @@ def none_approach(
         # Normalize message content to ensure it's always string
         normalized_messages = normalize_message_content(original_messages)
 
+        # Prepare request data for logging
+        provider_request = {
+            "model": model,
+            "messages": normalized_messages,
+            **kwargs
+        }
+
         # Make the direct completion call with normalized messages and parameters
         response = client.chat.completions.create(
             model=model,

@@ -228,11 +240,18 @@ def none_approach(
         )
 
         # Convert to dict if it's not already
-        if hasattr(response, 'model_dump'):
-            return response.model_dump()
-        return response
+        response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+
+        # Log the provider call if conversation logging is enabled
+        if conversation_logger and request_id:
+            conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
+        return response_dict
 
     except Exception as e:
+        # Log error if conversation logging is enabled
+        if conversation_logger and request_id:
+            conversation_logger.log_error(request_id, f"Error in none approach: {str(e)}")
         logger.error(f"Error in none approach: {str(e)}")
         raise
 
@@ -345,7 +364,7 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache
 
     return operation, approaches, actual_model
 
-def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None):
+def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None, request_id: str = None):
     if approach in known_approaches:
         if approach == 'none':
             # Extract kwargs from the request data

@@ -356,41 +375,41 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
             # Copy all parameters except 'stream', 'model' and 'messages'
             kwargs = {k: v for k, v in data.items()
                      if k not in ['model', 'messages', 'stream', 'optillm_approach']}
-            response = none_approach(original_messages=messages, client=client, model=model, **kwargs)
+            response = none_approach(original_messages=messages, client=client, model=model, request_id=request_id, **kwargs)
             # For none approach, we return the response and a token count of 0
             # since the full token count is already in the response
             return response, 0
         elif approach == 'mcts':
             return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
-                                  server_config['mcts_exploration'], server_config['mcts_depth'])
+                                  server_config['mcts_exploration'], server_config['mcts_depth'], request_id)
         elif approach == 'bon':
-            return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
+            return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_id)
         elif approach == 'moa':
-            return mixture_of_agents(system_prompt, initial_query, client, model)
+            return mixture_of_agents(system_prompt, initial_query, client, model, request_id)
         elif approach == 'rto':
-            return round_trip_optimization(system_prompt, initial_query, client, model)
+            return round_trip_optimization(system_prompt, initial_query, client, model, request_id)
         elif approach == 'z3':
-            z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
+            z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_id=request_id)
             return z3_solver.process_query(initial_query)
         elif approach == "self_consistency":
-            return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
+            return advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_id)
        elif approach == "pvg":
-            return inference_time_pv_game(system_prompt, initial_query, client, model)
+            return inference_time_pv_game(system_prompt, initial_query, client, model, request_id)
         elif approach == "rstar":
             rstar = RStar(system_prompt, client, model,
                           max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
-                          c=server_config['rstar_c'])
+                          c=server_config['rstar_c'], request_id=request_id)
             return rstar.solve(initial_query)
         elif approach == "cot_reflection":
-            return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config)
+            return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config, request_id=request_id)
         elif approach == 'plansearch':
-            return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
+            return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
         elif approach == 'leap':
-            return leap(system_prompt, initial_query, client, model)
+            return leap(system_prompt, initial_query, client, model, request_id)
         elif approach == 're2':
-            return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
+            return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
         elif approach == 'cepo':
-            return cepo(system_prompt, initial_query, client, model, cepo_config)
+            return cepo(system_prompt, initial_query, client, model, cepo_config, request_id)
     elif approach in plugin_approaches:
         # Check if the plugin accepts request_config
         plugin_func = plugin_approaches[approach]

@@ -445,7 +464,7 @@ async def run_approach(approach):
     return list(responses), sum(tokens)
 
 def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str,
-                    request_config: dict = None) -> Tuple[Union[str, List[str]], int]:
+                    request_config: dict = None, request_id: str = None) -> Tuple[Union[str, List[str]], int]:
     """
     Execute the pipeline n times and return n responses.
 
@@ -466,7 +485,7 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init
 
     for _ in range(n):
         if operation == 'SINGLE':
-            response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
+            response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id)
         elif operation == 'AND':
             response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config)
         elif operation == 'OR':

@@ -678,6 +697,21 @@ def proxy():
     operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
     logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')
 
+    # Start conversation logging if enabled
+    request_id = None
+    if conversation_logger and conversation_logger.enabled:
+        request_id = conversation_logger.start_conversation(
+            client_request={
+                'messages': messages,
+                'model': data.get('model', server_config['model']),
+                'stream': stream,
+                'n': n,
+                **{k: v for k, v in data.items() if k not in {'messages', 'model', 'stream', 'n'}}
+            },
+            approach=approaches[0] if len(approaches) == 1 else f"{operation}({','.join(approaches)})",
+            model=model
+        )
+
     if bearer_token != "" and bearer_token.startswith("sk-"):
         api_key = bearer_token
         if base_url != "":

@@ -718,10 +752,15 @@ def proxy():
 
         if operation == 'SINGLE' and approaches[0] == 'none':
             # Pass through the request including the n parameter
-            result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
+            result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id)
 
             logger.debug(f'Direct proxy response: {result}')
 
+            # Log the final response and finalize conversation logging
+            if conversation_logger and request_id:
+                conversation_logger.log_final_response(request_id, result)
+                conversation_logger.finalize_conversation(request_id)
+
             if stream:
                 return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream')
             else :

@@ -732,9 +771,14 @@ def proxy():
             raise ValueError("'none' approach cannot be combined with other approaches")
 
         # Handle non-none approaches with n attempts
-        response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config)
+        response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config, request_id)
 
     except Exception as e:
+        # Log error to conversation logger if enabled
+        if conversation_logger and request_id:
+            conversation_logger.log_error(request_id, str(e))
+            conversation_logger.finalize_conversation(request_id)
+
         logger.error(f"Error processing request: {str(e)}")
         return jsonify({"error": str(e)}), 500
 
@@ -793,6 +837,11 @@ def proxy():
             'finish_reason': 'stop'
         })
 
+    # Log the final response and finalize conversation logging
+    if conversation_logger and request_id:
+        conversation_logger.log_final_response(request_id, response_data)
+        conversation_logger.finalize_conversation(request_id)
+
     logger.debug(f'API response: {response_data}')
     return jsonify(response_data), 200
 
@@ -848,6 +897,8 @@ def parse_args():
         ("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())),
         ("--launch-gui", "OPTILLM_LAUNCH_GUI", bool, False, "Launch a Gradio chat interface"),
         ("--plugins-dir", "OPTILLM_PLUGINS_DIR", str, "", "Path to the plugins directory"),
+        ("--log-conversations", "OPTILLM_LOG_CONVERSATIONS", bool, False, "Enable conversation logging with full metadata"),
+        ("--conversation-log-dir", "OPTILLM_CONVERSATION_LOG_DIR", str, str(Path.home() / ".optillm" / "conversations"), "Directory to save conversation logs"),
     ]
 
     for arg, env, type_, default, help_text, *extra in args_env:

@@ -920,6 +971,7 @@ def main():
     global server_config
     global cepo_config
     global request_batcher
+    global conversation_logger
     # Call this function at the start of main()
     args = parse_args()
     # Update server_config with all argument values

@@ -1075,6 +1127,15 @@ def process_batch_requests(batch_requests):
     if logging_level in logging_levels.keys():
         logger.setLevel(logging_levels[logging_level])
 
+    # Initialize conversation logger if enabled
+    global conversation_logger
+    conversation_logger = ConversationLogger(
+        log_dir=Path(server_config['conversation_log_dir']),
+        enabled=server_config['log_conversations']
+    )
+    if server_config['log_conversations']:
+        logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}")
+
     # set and log the cepo configs
     cepo_config = init_cepo_config(server_config)
     if args.approach == 'cepo':
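
The new module optillm/conversation_logger.py is among the 19 files changed but is not shown in this excerpt. From the call sites above (start_conversation, log_provider_call, log_final_response, log_error, finalize_conversation, and the enabled flag checked in proxy()), a minimal sketch of the interface can be inferred. The following is a hypothetical reconstruction, assuming a thread-safe in-memory record per request that is flushed to a JSONL file on finalize; the actual implementation and on-disk format may differ.

import json
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path

class ConversationLogger:
    # Hypothetical sketch inferred from call sites; not the actual
    # optillm/conversation_logger.py shipped in this commit.
    def __init__(self, log_dir: Path, enabled: bool = False):
        self.log_dir = log_dir
        self.enabled = enabled
        self._lock = threading.Lock()
        self._active = {}  # request_id -> accumulating conversation record
        if enabled:
            log_dir.mkdir(parents=True, exist_ok=True)

    def start_conversation(self, client_request: dict, approach: str, model: str) -> str:
        request_id = str(uuid.uuid4())
        with self._lock:
            self._active[request_id] = {
                "request_id": request_id,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "approach": approach,
                "model": model,
                "client_request": client_request,
                "provider_calls": [],
                "errors": [],
            }
        return request_id

    def log_provider_call(self, request_id: str, request: dict, response: dict) -> None:
        with self._lock:
            if request_id in self._active:
                self._active[request_id]["provider_calls"].append(
                    {"request": request, "response": response})

    def log_final_response(self, request_id: str, response) -> None:
        with self._lock:
            if request_id in self._active:
                self._active[request_id]["final_response"] = response

    def log_error(self, request_id: str, error: str) -> None:
        with self._lock:
            if request_id in self._active:
                self._active[request_id]["errors"].append(error)

    def finalize_conversation(self, request_id: str) -> None:
        # Flush the completed record as one JSON line (assumed format).
        with self._lock:
            record = self._active.pop(request_id, None)
        if record is not None and self.enabled:
            with open(self.log_dir / "conversations.jsonl", "a") as f:
                f.write(json.dumps(record, default=str) + "\n")

Under these assumptions, the feature is switched on with the flags added in parse_args, for example: python optillm.py --log-conversations --conversation-log-dir ~/.optillm/conversations, or via the OPTILLM_LOG_CONVERSATIONS and OPTILLM_CONVERSATION_LOG_DIR environment variables.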

optillm/bon.py (43 additions, 21 deletions)

@@ -1,8 +1,9 @@
 import logging
+import optillm
 
 logger = logging.getLogger(__name__)
 
-def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
+def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
     bon_completion_tokens = 0
 
     messages = [{"role": "system", "content": system_prompt},

@@ -12,13 +13,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
 
     try:
         # Try to generate n completions in a single API call using n parameter
-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            max_tokens=4096,
-            n=n,
-            temperature=1
-        )
+        provider_request = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": 4096,
+            "n": n,
+            "temperature": 1
+        }
+        response = client.chat.completions.create(**provider_request)
+
+        # Log provider call
+        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+            optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
         completions = [choice.message.content for choice in response.choices]
         logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
         bon_completion_tokens += response.usage.completion_tokens

@@ -30,12 +38,19 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         # Fallback: Generate completions one by one in a loop
         for i in range(n):
             try:
-                response = client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    max_tokens=4096,
-                    temperature=1
-                )
+                provider_request = {
+                    "model": model,
+                    "messages": messages,
+                    "max_tokens": 4096,
+                    "temperature": 1
+                }
+                response = client.chat.completions.create(**provider_request)
+
+                # Log provider call
+                if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+                    response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+                    optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
                 completions.append(response.choices[0].message.content)
                 bon_completion_tokens += response.usage.completion_tokens
                 logger.debug(f"Generated completion {i+1}/{n}")

@@ -59,13 +74,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         rating_messages.append({"role": "assistant", "content": completion})
         rating_messages.append({"role": "user", "content": "Rate the above response:"})
 
-        rating_response = client.chat.completions.create(
-            model=model,
-            messages=rating_messages,
-            max_tokens=256,
-            n=1,
-            temperature=0.1
-        )
+        provider_request = {
+            "model": model,
+            "messages": rating_messages,
+            "max_tokens": 256,
+            "n": 1,
+            "temperature": 0.1
+        }
+        rating_response = client.chat.completions.create(**provider_request)
+
+        # Log provider call
+        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+            response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
+            optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
         bon_completion_tokens += rating_response.usage.completion_tokens
         try:
             rating = float(rating_response.choices[0].message.content.strip())
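
Because bon.py has no direct reference to the logger instance, it reaches it through the optillm module attribute, guarded with hasattr so the code degrades gracefully when no logger is configured, and the same build-request, call, log sequence appears three times. A small helper could factor out that repetition; the sketch below uses the hypothetical name _logged_completion, which is not part of this commit:

import optillm

def _logged_completion(client, provider_request: dict, request_id: str = None):
    """Call chat.completions.create and mirror the request/response
    pair to the global conversation logger when one is configured."""
    response = client.chat.completions.create(**provider_request)
    conv_logger = getattr(optillm, 'conversation_logger', None)
    if conv_logger and request_id:
        response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
        conv_logger.log_provider_call(request_id, provider_request, response_dict)
    return response

Each of the three call sites in best_of_n_sampling would then reduce to response = _logged_completion(client, provider_request, request_id).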
