diff --git a/README.md b/README.md
index e4fe2460..5505427d 100644
--- a/README.md
+++ b/README.md
@@ -287,7 +287,7 @@ and then use the same in your OpenAI client. You can pass any HuggingFace model
with your HuggingFace key. We also support adding any number of LoRAs on top of the model by using the `+` separator.
E.g. The following code loads the base model `meta-llama/Llama-3.2-1B-Instruct` and then adds two LoRAs on top - `patched-codes/Llama-3.2-1B-FixVulns` and `patched-codes/Llama-3.2-1B-FastApply`.
-You can specify which LoRA to use using the `active_adapter` param in `extra_args` field of OpenAI SDK client. By default we will load the last specified adapter.
+You can specify which LoRA to use via the `active_adapter` param in the `extra_body` field of the OpenAI SDK client. By default, we load the last specified adapter.
```python
OPENAI_BASE_URL = "http://localhost:8000/v1"
@@ -748,4 +748,4 @@ If you use this library in your research, please cite:
⭐ Star us on GitHub if you find OptiLLM useful!
-
\ No newline at end of file
+
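For reference on the `extra_args` → `extra_body` rename above, here is a minimal sketch of selecting an adapter from the OpenAI SDK client (not part of the patch; the API key value and the message content are illustrative assumptions):

```python
from openai import OpenAI

OPENAI_BASE_URL = "http://localhost:8000/v1"
client = OpenAI(api_key="optillm", base_url=OPENAI_BASE_URL)  # key value is a placeholder

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct+patched-codes/Llama-3.2-1B-FixVulns+patched-codes/Llama-3.2-1B-FastApply",
    messages=[{"role": "user", "content": "Hello"}],
    # extra_body is forwarded as additional JSON body fields; the server reads active_adapter from it
    extra_body={"active_adapter": "patched-codes/Llama-3.2-1B-FastApply"},
)
print(response.choices[0].message.content)
```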
diff --git a/optillm.py b/optillm.py
index 70582c18..5df85030 100644
--- a/optillm.py
+++ b/optillm.py
@@ -3,6 +3,7 @@
import os
import secrets
import time
+from pathlib import Path
from flask import Flask, request, jsonify
from cerebras.cloud.sdk import Cerebras
from openai import AzureOpenAI, OpenAI
@@ -32,6 +33,8 @@
from optillm.reread import re2_approach
from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
from optillm.batching import RequestBatcher, BatchingError
+from optillm.conversation_logger import ConversationLogger
+import optillm.conversation_logger
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -51,6 +54,9 @@
# Global request batcher (initialized in main() if batch mode enabled)
request_batcher = None
+# Global conversation logger (initialized in main() if logging enabled)
+conversation_logger = None
+
def get_config():
API_KEY = None
if os.environ.get("OPTILLM_API_KEY"):
@@ -196,17 +202,17 @@ def none_approach(
client: Any,
model: str,
original_messages: List[Dict[str, str]],
+ request_id: str = None,
**kwargs
) -> Dict[str, Any]:
"""
Direct proxy approach that passes through all parameters to the underlying endpoint.
Args:
- system_prompt: System prompt text (unused)
- initial_query: Initial query/conversation (unused)
client: OpenAI client instance
model: Model identifier
original_messages: Original messages from the request
+ request_id: Optional request ID for conversation logging
**kwargs: Additional parameters to pass through
Returns:
@@ -220,6 +226,13 @@ def none_approach(
# Normalize message content to ensure it's always string
normalized_messages = normalize_message_content(original_messages)
+ # Prepare request data for logging
+ provider_request = {
+ "model": model,
+ "messages": normalized_messages,
+ **kwargs
+ }
+
# Make the direct completion call with normalized messages and parameters
response = client.chat.completions.create(
model=model,
@@ -228,11 +241,18 @@ def none_approach(
)
# Convert to dict if it's not already
- if hasattr(response, 'model_dump'):
- return response.model_dump()
- return response
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+
+ # Log the provider call if conversation logging is enabled
+ if conversation_logger and request_id:
+ conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
+ return response_dict
except Exception as e:
+ # Log error if conversation logging is enabled
+ if conversation_logger and request_id:
+ conversation_logger.log_error(request_id, f"Error in none approach: {str(e)}")
logger.error(f"Error in none approach: {str(e)}")
raise
@@ -345,52 +365,58 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache
return operation, approaches, actual_model
-def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None):
+def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None, request_id: str = None):
if approach in known_approaches:
if approach == 'none':
- # Extract kwargs from the request data
- kwargs = {}
- if hasattr(request, 'json'):
- data = request.get_json()
- messages = data.get('messages', [])
- # Copy all parameters except 'stream', 'model' and 'messages'
- kwargs = {k: v for k, v in data.items()
- if k not in ['model', 'messages', 'stream', 'optillm_approach']}
- response = none_approach(original_messages=messages, client=client, model=model, **kwargs)
+ # Use the request_config that was already prepared and passed to this function
+ kwargs = request_config.copy() if request_config else {}
+
+ # Remove items that are handled separately by the framework
+ kwargs.pop('n', None) # n is handled by execute_n_times
+ kwargs.pop('stream', None) # stream is handled by proxy()
+
+ # Reconstruct original messages from system_prompt and initial_query
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ if initial_query:
+ messages.append({"role": "user", "content": initial_query})
+
+ response = none_approach(original_messages=messages, client=client, model=model, request_id=request_id, **kwargs)
# For none approach, we return the response and a token count of 0
# since the full token count is already in the response
return response, 0
elif approach == 'mcts':
return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
- server_config['mcts_exploration'], server_config['mcts_depth'])
+ server_config['mcts_exploration'], server_config['mcts_depth'], request_id)
elif approach == 'bon':
- return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
+ return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_id)
elif approach == 'moa':
- return mixture_of_agents(system_prompt, initial_query, client, model)
+ return mixture_of_agents(system_prompt, initial_query, client, model, request_id)
elif approach == 'rto':
- return round_trip_optimization(system_prompt, initial_query, client, model)
+ return round_trip_optimization(system_prompt, initial_query, client, model, request_id)
elif approach == 'z3':
- z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
+ z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_id=request_id)
return z3_solver.process_query(initial_query)
elif approach == "self_consistency":
- return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
+ return advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_id)
elif approach == "pvg":
- return inference_time_pv_game(system_prompt, initial_query, client, model)
+ return inference_time_pv_game(system_prompt, initial_query, client, model, request_id)
elif approach == "rstar":
rstar = RStar(system_prompt, client, model,
max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
- c=server_config['rstar_c'])
+ c=server_config['rstar_c'], request_id=request_id)
return rstar.solve(initial_query)
elif approach == "cot_reflection":
- return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config)
+ return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config, request_id=request_id)
elif approach == 'plansearch':
- return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
+ return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
elif approach == 'leap':
- return leap(system_prompt, initial_query, client, model)
+ return leap(system_prompt, initial_query, client, model, request_id)
elif approach == 're2':
- return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
+ return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
elif approach == 'cepo':
- return cepo(system_prompt, initial_query, client, model, cepo_config)
+ return cepo(system_prompt, initial_query, client, model, cepo_config, request_id)
elif approach in plugin_approaches:
# Check if the plugin accepts request_config
plugin_func = plugin_approaches[approach]
@@ -445,7 +471,7 @@ async def run_approach(approach):
return list(responses), sum(tokens)
def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str,
- request_config: dict = None) -> Tuple[Union[str, List[str]], int]:
+ request_config: dict = None, request_id: str = None) -> Tuple[Union[str, List[str]], int]:
"""
Execute the pipeline n times and return n responses.
@@ -466,7 +492,7 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init
for _ in range(n):
if operation == 'SINGLE':
- response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
+ response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id)
elif operation == 'AND':
response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config)
elif operation == 'OR':
@@ -676,7 +702,27 @@ def proxy():
default_client, api_key = get_config()
operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
- logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')
+
+ # Start conversation logging if enabled
+ request_id = None
+ if conversation_logger and conversation_logger.enabled:
+ request_id = conversation_logger.start_conversation(
+ client_request={
+ 'messages': messages,
+ 'model': data.get('model', server_config['model']),
+ 'stream': stream,
+ 'n': n,
+ **{k: v for k, v in data.items() if k not in {'messages', 'model', 'stream', 'n'}}
+ },
+ approach=approaches[0] if len(approaches) == 1 else f"{operation}({','.join(approaches)})",
+ model=model
+ )
+
+ # Log approach and request start with ID for terminal monitoring
+ request_id_str = f' [Request: {request_id}]' if request_id else ''
+ logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}{request_id_str}')
+ if request_id:
+ logger.info(f'Request {request_id}: Starting processing')
if bearer_token != "" and bearer_token.startswith("sk-"):
api_key = bearer_token
@@ -718,13 +764,22 @@ def proxy():
if operation == 'SINGLE' and approaches[0] == 'none':
# Pass through the request including the n parameter
- result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
+ result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id)
logger.debug(f'Direct proxy response: {result}')
+ # Log the final response and finalize conversation logging
+ if conversation_logger and request_id:
+ conversation_logger.log_final_response(request_id, result)
+ conversation_logger.finalize_conversation(request_id)
+
if stream:
+ if request_id:
+ logger.info(f'Request {request_id}: Completed (streaming response)')
return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream')
else :
+ if request_id:
+ logger.info(f'Request {request_id}: Completed')
return jsonify(result), 200
elif operation == 'AND' or operation == 'OR':
@@ -732,10 +787,16 @@ def proxy():
raise ValueError("'none' approach cannot be combined with other approaches")
# Handle non-none approaches with n attempts
- response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config)
+ response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config, request_id)
except Exception as e:
- logger.error(f"Error processing request: {str(e)}")
+ # Log error to conversation logger if enabled
+ if conversation_logger and request_id:
+ conversation_logger.log_error(request_id, str(e))
+ conversation_logger.finalize_conversation(request_id)
+
+ request_id_str = f' {request_id}' if request_id else ''
+ logger.error(f"Error processing request{request_id_str}: {str(e)}")
return jsonify({"error": str(e)}), 500
# Convert tagged conversation to messages format if needed
@@ -793,7 +854,14 @@ def proxy():
'finish_reason': 'stop'
})
+ # Log the final response and finalize conversation logging
+ if conversation_logger and request_id:
+ conversation_logger.log_final_response(request_id, response_data)
+ conversation_logger.finalize_conversation(request_id)
+
logger.debug(f'API response: {response_data}')
+ if request_id:
+ logger.info(f'Request {request_id}: Completed')
return jsonify(response_data), 200
@app.route('/v1/models', methods=['GET'])
@@ -848,6 +916,8 @@ def parse_args():
("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())),
("--launch-gui", "OPTILLM_LAUNCH_GUI", bool, False, "Launch a Gradio chat interface"),
("--plugins-dir", "OPTILLM_PLUGINS_DIR", str, "", "Path to the plugins directory"),
+ ("--log-conversations", "OPTILLM_LOG_CONVERSATIONS", bool, False, "Enable conversation logging with full metadata"),
+ ("--conversation-log-dir", "OPTILLM_CONVERSATION_LOG_DIR", str, str(Path.home() / ".optillm" / "conversations"), "Directory to save conversation logs"),
]
for arg, env, type_, default, help_text, *extra in args_env:
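A minimal sketch of turning the new options on when launching the proxy (assumptions: the server is started from the repo root via `optillm.py`, the bool options act as switches like the existing `--launch-gui` flag, and the directory shown is just the default spelled out):

```python
import subprocess
from pathlib import Path

# Same value as the --conversation-log-dir default declared above.
log_dir = Path.home() / ".optillm" / "conversations"

# Equivalent environment variables: OPTILLM_LOG_CONVERSATIONS, OPTILLM_CONVERSATION_LOG_DIR.
subprocess.run([
    "python", "optillm.py",
    "--log-conversations",
    "--conversation-log-dir", str(log_dir),
])
```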
@@ -920,6 +990,7 @@ def main():
global server_config
global cepo_config
global request_batcher
+ global conversation_logger
# Call this function at the start of main()
args = parse_args()
# Update server_config with all argument values
@@ -1075,6 +1146,17 @@ def process_batch_requests(batch_requests):
if logging_level in logging_levels.keys():
logger.setLevel(logging_levels[logging_level])
+ # Initialize conversation logger if enabled
+ global conversation_logger
+ conversation_logger = ConversationLogger(
+ log_dir=Path(server_config['conversation_log_dir']),
+ enabled=server_config['log_conversations']
+ )
+ # Set the global logger instance for access from approach modules
+ optillm.conversation_logger.set_global_logger(conversation_logger)
+ if server_config['log_conversations']:
+ logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}")
+
# set and log the cepo configs
cepo_config = init_cepo_config(server_config)
if args.approach == 'cepo':
diff --git a/optillm/__init__.py b/optillm/__init__.py
index e7c0d493..cbb5521b 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -2,7 +2,7 @@
import os
# Version information
-__version__ = "0.1.28"
+__version__ = "0.2.0"
# Get the path to the root optillm.py
spec = util.spec_from_file_location(
diff --git a/optillm/bon.py b/optillm/bon.py
index 3da7d140..3e5885df 100644
--- a/optillm/bon.py
+++ b/optillm/bon.py
@@ -1,8 +1,10 @@
import logging
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
-def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
+def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
bon_completion_tokens = 0
messages = [{"role": "system", "content": system_prompt},
@@ -12,13 +14,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
try:
# Try to generate n completions in a single API call using n parameter
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=4096,
- n=n,
- temperature=1
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 4096,
+ "n": n,
+ "temperature": 1
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
completions = [choice.message.content for choice in response.choices]
logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
bon_completion_tokens += response.usage.completion_tokens
@@ -30,12 +39,19 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
# Fallback: Generate completions one by one in a loop
for i in range(n):
try:
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=4096,
- temperature=1
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 4096,
+ "temperature": 1
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
completions.append(response.choices[0].message.content)
bon_completion_tokens += response.usage.completion_tokens
logger.debug(f"Generated completion {i+1}/{n}")
@@ -59,13 +75,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
rating_messages.append({"role": "assistant", "content": completion})
rating_messages.append({"role": "user", "content": "Rate the above response:"})
- rating_response = client.chat.completions.create(
- model=model,
- messages=rating_messages,
- max_tokens=256,
- n=1,
- temperature=0.1
- )
+ provider_request = {
+ "model": model,
+ "messages": rating_messages,
+ "max_tokens": 256,
+ "n": 1,
+ "temperature": 0.1
+ }
+ rating_response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if request_id:
+ response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
+ conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
bon_completion_tokens += rating_response.usage.completion_tokens
try:
rating = float(rating_response.choices[0].message.content.strip())
diff --git a/optillm/cepo/cepo.py b/optillm/cepo/cepo.py
index f58d2694..3f098bcf 100644
--- a/optillm/cepo/cepo.py
+++ b/optillm/cepo/cepo.py
@@ -1,6 +1,8 @@
import re
import yaml
import json
+import optillm
+from optillm import conversation_logger
from dataclasses import dataclass
from typing import Literal, Any, Optional
@@ -58,7 +60,7 @@ def extract_question_only(task: str) -> str:
return question_only
-def generate_completion(system_prompt: str, task: str, client: Any, model: str, cepo_config: CepoConfig, approach: Optional[str] = None) -> str:
+def generate_completion(system_prompt: str, task: str, client: Any, model: str, cepo_config: CepoConfig, approach: Optional[str] = None, request_id: str = None) -> str:
"""
Generates a completion based on the provided system prompt and task.
@@ -94,13 +96,22 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
f"to execute it correctly. Here is the question:\n{question_only}\nRead the question again:\n\n{question_only}"
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}]
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=cepo_config.planning_max_tokens_step1,
- temperature=cepo_config.planning_temperature_step1,
- stream=False,
- )
+
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": cepo_config.planning_max_tokens_step1,
+ "temperature": cepo_config.planning_temperature_step1,
+ "stream": False,
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
completion_tokens += response.usage.completion_tokens
if response.choices[0].finish_reason == "length":
@@ -111,13 +122,22 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
content = f"Can you execute the above plan step-by-step to produce the final answer. "\
f"Be extra careful when executing steps where your confidence is lower."
messages.extend([{"role": "assistant", "content": response.choices[0].message.content}, {"role": "user", "content": content}])
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=cepo_config.planning_max_tokens_step2,
- temperature=cepo_config.planning_temperature_step2,
- stream=False,
- )
+
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": cepo_config.planning_max_tokens_step2,
+ "temperature": cepo_config.planning_temperature_step2,
+ "stream": False,
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
completion_tokens += response.usage.completion_tokens
if response.choices[0].finish_reason == "length":
@@ -154,13 +174,21 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
f"it and present a final step-by-step solution to the problem? Here is the question:\n{question_only}"
messages = [{"role": "assistant", "content": plans_message}, {"role": "user", "content": content}]
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=cepo_config.planning_max_tokens_step3,
- temperature=cepo_config.planning_temperature_step3,
- stream=False,
- )
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": cepo_config.planning_max_tokens_step3,
+ "temperature": cepo_config.planning_temperature_step3,
+ "stream": False,
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
final_solution = response.choices[0].message.content
completion_tokens += response.usage.completion_tokens
except (CerebrasBadRequestError, OpenAIBadRequestError) as e:
@@ -172,13 +200,21 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
content = f"Use your final solution from above to correctly answer the question. Here is the question:\n{task}"
messages = [{"role": "assistant", "content": final_solution}, {"role": "user", "content": content}]
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=cepo_config.planning_max_tokens_step4,
- temperature=cepo_config.planning_temperature_step4,
- stream=False,
- )
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": cepo_config.planning_max_tokens_step4,
+ "temperature": cepo_config.planning_temperature_step4,
+ "stream": False,
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
completion_tokens += response.usage.completion_tokens
cb_log["messages"] = messages
@@ -187,7 +223,7 @@ def generate_completion(system_prompt: str, task: str, client: Any, model: str,
return response.choices[0].message.content, completion_tokens, cb_log
-def generate_approaches(system_prompt: str, initial_query: str, num_approach: int, client: Any, model: str, cepo_config: CepoConfig, max_retry: int = 2) -> tuple[list[str], int]:
+def generate_approaches(system_prompt: str, initial_query: str, num_approach: int, client: Any, model: str, cepo_config: CepoConfig, max_retry: int = 2, request_id: str = None) -> tuple[list[str], int]:
completion_tokens = 0
question_only = extract_question_only(initial_query)
approaches = []
@@ -205,13 +241,21 @@ def generate_approaches(system_prompt: str, initial_query: str, num_approach: in
retries = 0
while retries < max_retry:
try:
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=cepo_config.planning_max_tokens_step0,
- temperature=cepo_config.planning_temperature_step0,
- stream=False,
- )
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": cepo_config.planning_max_tokens_step0,
+ "temperature": cepo_config.planning_temperature_step0,
+ "stream": False,
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
completion_tokens += response.usage.completion_tokens
completion = response.choices[0].message.content
@@ -265,6 +309,7 @@ def generate_n_completions(system_prompt: str, initial_query: str, client: Any,
client=client,
model=model,
cepo_config=cepo_config,
+ request_id=request_id
)
cb_log["approaches"] = approaches
completion_tokens += approach_completion_tokens
@@ -275,7 +320,7 @@ def generate_n_completions(system_prompt: str, initial_query: str, client: Any,
if cepo_config.print_output:
print(f"\nCePO: Generating completion {i + 1} out of {cepo_config.bestofn_n} \n")
approach = approaches[i] if approaches else None
- response_i, completion_tokens_i, cb_log_i = generate_completion(system_prompt, initial_query, client, model, cepo_config, approach)
+ response_i, completion_tokens_i, cb_log_i = generate_completion(system_prompt, initial_query, client, model, cepo_config, approach, request_id)
completions.append(response_i)
completion_tokens += completion_tokens_i
cb_log[f"completion_{i}_response"] = response_i
@@ -285,7 +330,7 @@ def generate_n_completions(system_prompt: str, initial_query: str, client: Any,
return completions, completion_tokens, cb_log
-def rate_completions_absolute(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict) -> tuple[str, int, dict]:
+def rate_completions_absolute(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict, request_id: str = None) -> tuple[str, int, dict]:
"""
Rates completions for the Best of N step of CePO. Each completion is rated on a scale of 1 to 10 individually.
@@ -327,12 +372,20 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
"by strictly following this format: \"Explanation: \n\nRating: [[rating]]\"."
rating_messages.append({"role": "user", "content": content})
- rating_response = client.chat.completions.create(
- model=model,
- messages=rating_messages,
- max_tokens=cepo_config.bestofn_max_tokens,
- temperature=cepo_config.bestofn_temperature
- )
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": rating_messages,
+ "max_tokens": cepo_config.bestofn_max_tokens,
+ "temperature": cepo_config.bestofn_temperature
+ }
+
+ rating_response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
completion_tokens += rating_response.usage.completion_tokens
rating_response = rating_response.choices[0].message.content.strip()
@@ -359,7 +412,7 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
return completions[best_index], completion_tokens, cb_log
-def rate_completions_pairwise(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict) -> tuple[str, int, dict]:
+def rate_completions_pairwise(system_prompt: str, initial_query: str, client: Any, model: str, completions: list[str], cepo_config: CepoConfig, cb_log: dict, request_id: str = None) -> tuple[str, int, dict]:
"""
Rates completions for the Best of N step of CePO. Completions are rated pairwise against each other in both orders (A vs B and B vs A).
@@ -405,12 +458,20 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
"If the second response is better, reply with \"Better Response: [[1]]\"."
rating_messages.append({"role": "system", "content": content})
- rating_response = client.chat.completions.create(
- model=model,
- messages=rating_messages,
- max_tokens=cepo_config.bestofn_max_tokens,
- temperature=cepo_config.bestofn_temperature
- )
+ # Prepare request for logging
+ provider_request = {
+ "model": model,
+ "messages": rating_messages,
+ "max_tokens": cepo_config.bestofn_max_tokens,
+ "temperature": cepo_config.bestofn_temperature
+ }
+
+ rating_response = client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
completion_tokens += rating_response.usage.completion_tokens
rating_response = rating_response.choices[0].message.content.strip()
@@ -440,7 +501,7 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
return completions[best_index], completion_tokens, cb_log
-def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_config: CepoConfig) -> tuple[str, int]:
+def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_config: CepoConfig, request_id: str = None) -> tuple[str, int]:
"""
Applies CePO reasoning flow for the given task. First, it generates multiple completions, and then rates them to select the best one.
Each completion is generated as follows:
@@ -477,6 +538,6 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c
raise ValueError("Invalid rating type in cepo_config")
rating_model = cepo_config.rating_model if cepo_config.rating_model else model
- best_completion, completion_tokens_rating, cb_log = rate_completions_fn(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log)
+ best_completion, completion_tokens_rating, cb_log = rate_completions_fn(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
return best_completion, completion_tokens_planning + completion_tokens_rating
diff --git a/optillm/conversation_logger.py b/optillm/conversation_logger.py
new file mode 100644
index 00000000..6c77cd3b
--- /dev/null
+++ b/optillm/conversation_logger.py
@@ -0,0 +1,265 @@
+import json
+import logging
+import threading
+import uuid
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+from dataclasses import dataclass, field
+import time
+
+logger = logging.getLogger(__name__)
+
+# Global logger instance - will be set by optillm.py
+_global_logger: Optional['ConversationLogger'] = None
+
+@dataclass
+class ConversationEntry:
+ """Represents a single conversation entry being logged"""
+ request_id: str
+ timestamp: str
+ approach: str
+ model: str
+ client_request: Dict[str, Any]
+ provider_calls: List[Dict[str, Any]] = field(default_factory=list)
+ final_response: Optional[Dict[str, Any]] = None
+ total_duration_ms: Optional[int] = None
+ error: Optional[str] = None
+ start_time: float = field(default_factory=time.time)
+
+class ConversationLogger:
+ """
+ Logger for OptiLLM conversations including all provider interactions and metadata.
+
+ Logs are saved in JSONL format (one JSON object per line) with daily rotation.
+ Each entry contains the full conversation including all intermediate provider calls.
+ """
+
+ def __init__(self, log_dir: Path, enabled: bool = False):
+ self.enabled = enabled
+ self.log_dir = log_dir
+ self.active_entries: Dict[str, ConversationEntry] = {}
+ self._lock = threading.Lock()
+
+ if self.enabled:
+ self.log_dir.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Conversation logging enabled. Logs will be saved to: {self.log_dir}")
+ else:
+ logger.debug("Conversation logging disabled")
+
+ def _get_log_file_path(self, timestamp: datetime = None) -> Path:
+ """Get the log file path for a given timestamp (defaults to now)"""
+ if timestamp is None:
+ timestamp = datetime.now(timezone.utc)
+ date_str = timestamp.strftime("%Y-%m-%d")
+ return self.log_dir / f"conversations_{date_str}.jsonl"
+
+ def _generate_request_id(self) -> str:
+ """Generate a unique request ID"""
+ return f"req_{uuid.uuid4().hex[:8]}"
+
+ def start_conversation(self,
+ client_request: Dict[str, Any],
+ approach: str,
+ model: str) -> str:
+ """
+ Start logging a new conversation.
+
+ Args:
+ client_request: The original request from the client
+ approach: The optimization approach being used
+ model: The model name
+
+ Returns:
+ str: Unique request ID for this conversation
+ """
+ if not self.enabled:
+ return ""
+
+ request_id = self._generate_request_id()
+ timestamp = datetime.now(timezone.utc).isoformat()
+
+ entry = ConversationEntry(
+ request_id=request_id,
+ timestamp=timestamp,
+ approach=approach,
+ model=model,
+ client_request=client_request.copy()
+ )
+
+ with self._lock:
+ self.active_entries[request_id] = entry
+
+ logger.debug(f"Started conversation logging for request {request_id}")
+ return request_id
+
+ def log_provider_call(self,
+ request_id: str,
+ provider_request: Dict[str, Any],
+ provider_response: Dict[str, Any]) -> None:
+ """
+ Log a provider API call and response.
+
+ Args:
+ request_id: The request ID for this conversation
+ provider_request: The request sent to the provider
+ provider_response: The response received from the provider
+ """
+ if not self.enabled or not request_id:
+ return
+
+ with self._lock:
+ entry = self.active_entries.get(request_id)
+ if not entry:
+ logger.warning(f"No active conversation found for request {request_id}")
+ return
+
+ call_data = {
+ "call_number": len(entry.provider_calls) + 1,
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "request": provider_request.copy(),
+ "response": provider_response.copy()
+ }
+
+ entry.provider_calls.append(call_data)
+
+ logger.debug(f"Logged provider call #{len(entry.provider_calls)} for request {request_id}")
+
+ def log_final_response(self,
+ request_id: str,
+ final_response: Dict[str, Any]) -> None:
+ """
+ Log the final response sent back to the client.
+
+ Args:
+ request_id: The request ID for this conversation
+ final_response: The final response sent to the client
+ """
+ if not self.enabled or not request_id:
+ return
+
+ with self._lock:
+ entry = self.active_entries.get(request_id)
+ if not entry:
+ logger.warning(f"No active conversation found for request {request_id}")
+ return
+
+ entry.final_response = final_response.copy()
+ entry.final_response["timestamp"] = datetime.now(timezone.utc).isoformat()
+
+ def log_error(self, request_id: str, error: str) -> None:
+ """
+ Log an error for this conversation.
+
+ Args:
+ request_id: The request ID for this conversation
+ error: Error message or description
+ """
+ if not self.enabled or not request_id:
+ return
+
+ with self._lock:
+ entry = self.active_entries.get(request_id)
+ if not entry:
+ logger.warning(f"No active conversation found for request {request_id}")
+ return
+
+ entry.error = error
+
+ logger.debug(f"Logged error for request {request_id}: {error}")
+
+ def finalize_conversation(self, request_id: str) -> None:
+ """
+ Finalize and save the conversation to disk.
+
+ Args:
+ request_id: The request ID for this conversation
+ """
+ if not self.enabled or not request_id:
+ return
+
+ with self._lock:
+ entry = self.active_entries.pop(request_id, None)
+ if not entry:
+ logger.warning(f"No active conversation found for request {request_id}")
+ return
+
+ # Calculate total duration
+ entry.total_duration_ms = int((time.time() - entry.start_time) * 1000)
+
+ # Convert to dict for JSON serialization
+ log_entry = {
+ "timestamp": entry.timestamp,
+ "request_id": entry.request_id,
+ "approach": entry.approach,
+ "model": entry.model,
+ "client_request": entry.client_request,
+ "provider_calls": entry.provider_calls,
+ "final_response": entry.final_response,
+ "total_duration_ms": entry.total_duration_ms,
+ "error": entry.error
+ }
+
+ # Write to log file
+ self._write_log_entry(log_entry)
+
+ logger.debug(f"Finalized conversation for request {request_id}")
+
+ def _write_log_entry(self, log_entry: Dict[str, Any]) -> None:
+ """Write a log entry to the appropriate JSONL file"""
+ try:
+ log_file_path = self._get_log_file_path()
+ with open(log_file_path, 'a', encoding='utf-8') as f:
+ json.dump(log_entry, f, separators=(',', ':'))
+ f.write('\n')
+ logger.debug(f"Wrote log entry to {log_file_path}")
+ except Exception as e:
+ logger.error(f"Failed to write log entry: {e}")
+
+ def get_stats(self) -> Dict[str, Any]:
+ """Get statistics about conversation logging"""
+ with self._lock:
+ active_count = len(self.active_entries)
+
+ stats = {
+ "enabled": self.enabled,
+ "log_dir": str(self.log_dir),
+ "active_conversations": active_count
+ }
+
+ if self.enabled:
+ # Count total log files and approximate total entries
+ log_files = list(self.log_dir.glob("conversations_*.jsonl"))
+ total_entries = 0
+ for log_file in log_files:
+ try:
+ with open(log_file, 'r', encoding='utf-8') as f:
+ total_entries += sum(1 for line in f if line.strip())
+ except Exception:
+ pass
+
+ stats.update({
+ "log_files_count": len(log_files),
+ "total_entries_approximate": total_entries
+ })
+
+ return stats
+
+
+# Module-level functions for easy access from approach modules
+def set_global_logger(logger_instance: 'ConversationLogger') -> None:
+ """Set the global logger instance (called by optillm.py)"""
+ global _global_logger
+ _global_logger = logger_instance
+
+
+def log_provider_call(request_id: str, provider_request: Dict[str, Any], provider_response: Dict[str, Any]) -> None:
+ """Log a provider call using the global logger instance"""
+ if _global_logger and _global_logger.enabled:
+ _global_logger.log_provider_call(request_id, provider_request, provider_response)
+
+
+def log_error(request_id: str, error_message: str) -> None:
+ """Log an error using the global logger instance"""
+ if _global_logger and _global_logger.enabled:
+ _global_logger.log_error(request_id, error_message)
\ No newline at end of file
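Since `conversation_logger.py` is self-contained, a minimal usage sketch may help review; the model name, payloads, and log directory below are illustrative assumptions, and the read-back simply mirrors the keys written by `finalize_conversation`:

```python
import json
from pathlib import Path
from optillm.conversation_logger import ConversationLogger

clogger = ConversationLogger(log_dir=Path("/tmp/optillm-conversations"), enabled=True)

# Lifecycle mirrors proxy(): start, log provider calls, attach the final response, finalize.
request_id = clogger.start_conversation(
    client_request={"model": "my-model", "messages": [{"role": "user", "content": "Hi"}]},
    approach="moa",
    model="my-model",
)
clogger.log_provider_call(
    request_id,
    provider_request={"model": "my-model", "messages": [{"role": "user", "content": "Hi"}]},
    provider_response={"choices": [{"message": {"content": "Hello!"}}]},
)
clogger.log_final_response(request_id, {"choices": [{"message": {"content": "Hello!"}}]})
clogger.finalize_conversation(request_id)

# Each finalized conversation is appended as one JSON object per line to conversations_YYYY-MM-DD.jsonl.
log_file = next(iter(clogger.log_dir.glob("conversations_*.jsonl")))
for line in log_file.read_text(encoding="utf-8").splitlines():
    entry = json.loads(line)
    print(entry["request_id"], entry["approach"], entry["total_duration_ms"], len(entry["provider_calls"]))
```

Approach modules do not construct the logger themselves; they call the module-level `log_provider_call` / `log_error` helpers, which route to the instance registered via `set_global_logger`.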
diff --git a/optillm/cot_reflection.py b/optillm/cot_reflection.py
index dfc6efdb..4596f6fa 100644
--- a/optillm/cot_reflection.py
+++ b/optillm/cot_reflection.py
@@ -1,9 +1,11 @@
import re
import logging
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
-def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False, request_config: dict = None):
+def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False, request_config: dict = None, request_id: str = None):
cot_completion_tokens = 0
# Extract temperature and max_tokens from request_config with defaults
@@ -41,15 +43,21 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
"""
# Make the API call using user-provided or default parameters
- response = client.chat.completions.create(
- model=model,
- messages=[
+ provider_request = {
+ "model": model,
+ "messages": [
{"role": "system", "content": cot_prompt},
{"role": "user", "content": initial_query}
],
- temperature=temperature,
- max_tokens=max_tokens
- )
+ "temperature": temperature,
+ "max_tokens": max_tokens
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
# Extract the full response
full_response = response.choices[0].message.content
diff --git a/optillm/leap.py b/optillm/leap.py
index 2d4beb09..5f212d2c 100644
--- a/optillm/leap.py
+++ b/optillm/leap.py
@@ -2,16 +2,19 @@
import re
from typing import List, Tuple
import json
+import optillm
+from optillm import conversation_logger
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class LEAP:
- def __init__(self, system_prompt: str, client, model: str):
+ def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
self.system_prompt = system_prompt
self.client = client
self.model = model
+ self.request_id = request_id
self.low_level_principles = []
self.high_level_principles = []
self.leap_completion_tokens = 0
@@ -22,10 +25,12 @@ def extract_output(self, text: str) -> str:
def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str]]:
logger.info("Extracting examples from initial query")
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
Analyze the following query and determine if it contains few-shot examples.
@@ -46,7 +51,15 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
Query: {initial_query}
"""}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.leap_completion_tokens += response.usage.completion_tokens
examples_str = self.extract_output(response.choices[0].message.content)
logger.debug(f"Extracted examples: {examples_str}")
@@ -67,10 +80,11 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
logger.info("Generating mistakes for given examples")
mistakes = []
for question, correct_answer in examples:
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
Instruction: Answer the following question step by step. To induce a mistake,
@@ -80,8 +94,15 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
Think step by step, but make sure to include a mistake.
"""}
],
- temperature=0.7,
- )
+ "temperature": 0.7,
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.leap_completion_tokens += response.usage.completion_tokens
generated_reasoning = response.choices[0].message.content
generated_answer = self.extract_output(generated_reasoning)
@@ -92,10 +113,11 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]]) -> List[str]:
logger.info("Generating low-level principles from mistakes")
for question, generated_reasoning, generated_answer, correct_answer in mistakes:
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
Question: {question}
@@ -112,7 +134,14 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
Insights: Enclose ONLY the principles or insights within tags.
"""}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.leap_completion_tokens += response.usage.completion_tokens
self.low_level_principles.append(self.extract_output(response.choices[0].message.content))
return self.low_level_principles
@@ -120,10 +149,11 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
def generate_high_level_principles(self) -> List[str]:
logger.info("Generating high-level principles from low-level principles")
principles_text = "\n".join(self.low_level_principles)
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
Low-level principles: {principles_text}
@@ -137,7 +167,14 @@ def generate_high_level_principles(self) -> List[str]:
Enclose your list of principles within tags.
"""}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.leap_completion_tokens += response.usage.completion_tokens
self.high_level_principles = self.extract_output(response.choices[0].message.content).split("\n")
return self.high_level_principles
@@ -145,10 +182,11 @@ def generate_high_level_principles(self) -> List[str]:
def apply_principles(self, query: str) -> str:
logger.info("Applying learned principles to query")
principles_text = "\n".join(self.high_level_principles)
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
Please answer the following query. Keep in mind these principles:
@@ -158,7 +196,14 @@ def apply_principles(self, query: str) -> str:
Query: {query}
"""}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.leap_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content
@@ -175,6 +220,6 @@ def solve(self, initial_query: str) -> str:
return self.apply_principles(initial_query)
-def leap(system_prompt: str, initial_query: str, client, model: str) -> str:
- leap_solver = LEAP(system_prompt, client, model)
+def leap(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
+ leap_solver = LEAP(system_prompt, client, model, request_id)
return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
\ No newline at end of file
diff --git a/optillm/mcts.py b/optillm/mcts.py
index 2c55c955..d1727ea2 100644
--- a/optillm/mcts.py
+++ b/optillm/mcts.py
@@ -3,6 +3,8 @@
import numpy as np
import networkx as nx
from typing import List, Dict
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
@@ -24,7 +26,7 @@ def __init__(self, state: DialogueState, parent=None):
self.value = 0
class MCTS:
- def __init__(self, simulation_depth, exploration_weight, client, model):
+ def __init__(self, simulation_depth, exploration_weight, client, model, request_id=None):
self.simulation_depth = simulation_depth
self.exploration_weight = exploration_weight
self.root = None
@@ -33,6 +35,7 @@ def __init__(self, simulation_depth, exploration_weight, client, model):
self.client = client
self.model = model
self.completion_tokens = 0
+ self.request_id = request_id
def select(self, node: MCTSNode) -> MCTSNode:
logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
@@ -111,13 +114,20 @@ def generate_actions(self, state: DialogueState) -> List[str]:
n = 3
logger.info(f"Requesting {n} completions from the model")
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=4096,
- n=n,
- temperature=1
- )
+ provider_request = {
+ "model": self.model,
+ "messages": messages,
+ "max_tokens": 4096,
+ "n": n,
+ "temperature": 1
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
completions = [choice.message.content.strip() for choice in response.choices]
self.completion_tokens += response.usage.completion_tokens
logger.info(f"Received {len(completions)} completions from the model")
@@ -133,13 +143,19 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
messages.append({"role": "user", "content": "Based on this conversation, what might the user ask or say next? Provide a likely user query."})
logger.info("Requesting next user query from the model")
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=1024,
- n=1,
- temperature=1
- )
+ provider_request = {
+ "model": self.model,
+ "messages": messages,
+ "max_tokens": 1024,
+ "n": 1,
+ "temperature": 1
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
next_query = response.choices[0].message.content
self.completion_tokens += response.usage.completion_tokens
@@ -157,13 +173,20 @@ def evaluate_state(self, state: DialogueState) -> float:
messages.extend(state.conversation_history)
messages.append({"role": "user", "content": "Evaluate the quality of this conversation on a scale from 0 to 1, where 0 is poor and 1 is excellent. Consider factors such as coherence, relevance, and engagement. Respond with only a number."})
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=256,
- n=1,
- temperature=0.1
- )
+ provider_request = {
+ "model": self.model,
+ "messages": messages,
+ "max_tokens": 256,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.completion_tokens += response.usage.completion_tokens
try:
score = float(response.choices[0].message.content.strip())
@@ -175,10 +198,10 @@ def evaluate_state(self, state: DialogueState) -> float:
return 0.5 # Default to a neutral score if parsing fails
def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
- simulation_depth: int = 1) -> str:
+ simulation_depth: int = 1, request_id: str = None) -> str:
logger.info("Starting chat with MCTS")
logger.info(f"Parameters: num_simulations={num_simulations}, exploration_weight={exploration_weight}, simulation_depth={simulation_depth}")
- mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model)
+ mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_id=request_id)
initial_state = DialogueState(system_prompt, [], initial_query)
logger.info(f"Initial query: {initial_query}")
final_state = mcts.search(initial_state, num_simulations)
diff --git a/optillm/moa.py b/optillm/moa.py
index 21d5e105..9f6fd034 100644
--- a/optillm/moa.py
+++ b/optillm/moa.py
@@ -1,8 +1,10 @@
import logging
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
-def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str) -> str:
+def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
logger.info(f"Starting mixture_of_agents function with model: {model}")
moa_completion_tokens = 0
completions = []
@@ -11,16 +13,26 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
try:
# Try to generate 3 completions in a single API call using n parameter
- response = client.chat.completions.create(
- model=model,
- messages=[
+ provider_request = {
+ "model": model,
+ "messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}
],
- max_tokens=4096,
- n=3,
- temperature=1
- )
+ "max_tokens": 4096,
+ "n": 3,
+ "temperature": 1
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Convert response to dict for logging
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+
+ # Log provider call if conversation logging is enabled
+ if request_id:
+ conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
completions = [choice.message.content for choice in response.choices]
moa_completion_tokens += response.usage.completion_tokens
logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
@@ -33,15 +45,25 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
completions = []
for i in range(3):
try:
- response = client.chat.completions.create(
- model=model,
- messages=[
+ provider_request = {
+ "model": model,
+ "messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}
],
- max_tokens=4096,
- temperature=1
- )
+ "max_tokens": 4096,
+ "temperature": 1
+ }
+
+ response = client.chat.completions.create(**provider_request)
+
+ # Convert response to dict for logging
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
completions.append(response.choices[0].message.content)
moa_completion_tokens += response.usage.completion_tokens
logger.debug(f"Generated completion {i+1}/3")
@@ -83,16 +105,27 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
"""
logger.debug("Generating critiques")
- critique_response = client.chat.completions.create(
- model=model,
- messages=[
+
+ provider_request = {
+ "model": model,
+ "messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": critique_prompt}
],
- max_tokens=512,
- n=1,
- temperature=0.1
- )
+ "max_tokens": 512,
+ "n": 1,
+ "temperature": 0.1
+ }
+
+ critique_response = client.chat.completions.create(**provider_request)
+
+ # Convert response to dict for logging
+ response_dict = critique_response.model_dump() if hasattr(critique_response, 'model_dump') else critique_response
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
critiques = critique_response.choices[0].message.content
moa_completion_tokens += critique_response.usage.completion_tokens
logger.info(f"Generated critiques. Tokens used: {critique_response.usage.completion_tokens}")
@@ -119,16 +152,27 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
"""
logger.debug("Generating final response")
- final_response = client.chat.completions.create(
- model=model,
- messages=[
+
+ provider_request = {
+ "model": model,
+ "messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": final_prompt}
],
- max_tokens=8192,
- n=1,
- temperature=0.1
- )
+ "max_tokens": 8192,
+ "n": 1,
+ "temperature": 0.1
+ }
+
+ final_response = client.chat.completions.create(**provider_request)
+
+ # Convert response to dict for logging
+ response_dict = final_response.model_dump() if hasattr(final_response, 'model_dump') else final_response
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
moa_completion_tokens += final_response.usage.completion_tokens
logger.info(f"Generated final response. Tokens used: {final_response.usage.completion_tokens}")
diff --git a/optillm/plansearch.py b/optillm/plansearch.py
index 64632dc1..517bccc9 100644
--- a/optillm/plansearch.py
+++ b/optillm/plansearch.py
@@ -1,13 +1,16 @@
import logging
from typing import List, Tuple
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
class PlanSearch:
- def __init__(self, system_prompt: str, client, model: str):
+ def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
self.system_prompt = system_prompt
self.client = client
self.model = model
+ self.request_id = request_id
self.plansearch_completion_tokens = 0
def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
@@ -21,14 +24,22 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
Please provide {num_observations} observations."""
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.plansearch_completion_tokens += response.usage.completion_tokens
observations = response.choices[0].message.content.strip().split('\n')
return [obs.strip() for obs in observations if obs.strip()]
@@ -48,14 +59,22 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
Please provide {num_new_observations} new observations derived from the existing ones."""
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.plansearch_completion_tokens += response.usage.completion_tokens
new_observations = response.choices[0].message.content.strip().split('\n')
return [obs.strip() for obs in new_observations if obs.strip()]
@@ -73,14 +92,22 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
Quote relevant parts of the observations EXACTLY before each step of the solution. QUOTING
IS CRUCIAL."""
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.plansearch_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content.strip()
@@ -98,14 +125,22 @@ def implement_solution(self, problem: str, solution: str) -> str:
Please implement the solution in Python."""
- response = self.client.chat.completions.create(
- model=self.model,
- max_tokens=4096,
- messages=[
+ # Prepare request for logging
+ provider_request = {
+ "model": self.model,
+ "max_tokens": 4096,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
]
- )
+ }
+
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call if conversation logging is enabled
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
self.plansearch_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content.strip()
@@ -133,6 +168,6 @@ def solve_multiple(self, problem: str, n: int, num_initial_observations: int = 3
solutions.append(python_implementation)
return solutions
-def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1) -> List[str]:
- planner = PlanSearch(system_prompt, client, model)
+def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_id: str = None) -> List[str]:
+ planner = PlanSearch(system_prompt, client, model, request_id)
return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens
diff --git a/optillm/pvg.py b/optillm/pvg.py
index 8417f1e9..44c2b27d 100644
--- a/optillm/pvg.py
+++ b/optillm/pvg.py
@@ -1,12 +1,14 @@
import logging
import re
from typing import List, Tuple
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
pvg_completion_tokens = 0
-def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7) -> List[str]:
+def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, request_id: str = None) -> List[str]:
global pvg_completion_tokens
role = "sneaky" if is_sneaky else "helpful"
logger.info(f"Generating {num_solutions} {role} solutions")
@@ -30,19 +32,26 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
{"role": "system", "content": f"{system_prompt}\n{role_instruction}\nYou are in {role} mode."},
{"role": "user", "content": query}
]
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- n=num_solutions,
- max_tokens=4096,
- temperature=temperature,
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "n": num_solutions,
+ "max_tokens": 4096,
+ "temperature": temperature,
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
pvg_completion_tokens += response.usage.completion_tokens
solutions = [choice.message.content for choice in response.choices]
logger.debug(f"Generated {role} solutions: {solutions}")
return solutions
-def verify_solutions(client, system_prompt: str, initial_query: str, solutions: List[str], model: str) -> List[float]:
+def verify_solutions(client, system_prompt: str, initial_query: str, solutions: List[str], model: str, request_id: str = None) -> List[float]:
global pvg_completion_tokens
logger.info(f"Verifying {len(solutions)} solutions")
verify_prompt = f"""{system_prompt}
@@ -74,12 +83,19 @@ def verify_solutions(client, system_prompt: str, initial_query: str, solutions:
{"role": "system", "content": verify_prompt},
{"role": "user", "content": f"Problem: {initial_query}\n\nSolution: {solution}"}
]
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=1024,
- temperature=0.2,
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 1024,
+ "temperature": 0.2,
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
pvg_completion_tokens += response.usage.completion_tokens
rating = response.choices[0].message.content
logger.debug(f"Raw rating for solution {i+1}: {rating}")
@@ -135,7 +151,7 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
logger.warning("No answer found in the state.")
return "", 0.0
-def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3) -> str:
+def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_id: str = None) -> str:
global pvg_completion_tokens
logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")
@@ -147,11 +163,11 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
temperature = max(0.2, 0.7 - (round * 0.1))
- helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature)
- sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature)
+ helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, request_id=request_id)
+ sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, request_id=request_id)
all_solutions = helpful_solutions + sneaky_solutions
- scores = verify_solutions(client, system_prompt, initial_query, all_solutions, model)
+ scores = verify_solutions(client, system_prompt, initial_query, all_solutions, model, request_id=request_id)
round_best_solution = max(zip(all_solutions, scores), key=lambda x: x[1])
@@ -179,12 +195,19 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
{"role": "system", "content": system_prompt},
{"role": "user", "content": refine_prompt}
]
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=1024,
- temperature=0.5,
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 1024,
+ "temperature": 0.5,
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
pvg_completion_tokens += response.usage.completion_tokens
initial_query = response.choices[0].message.content
logger.debug(f"Refined query: {initial_query}")
diff --git a/optillm/reread.py b/optillm/reread.py
index 5b57d8b7..32706a3d 100644
--- a/optillm/reread.py
+++ b/optillm/reread.py
@@ -1,8 +1,10 @@
import logging
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
-def re2_approach(system_prompt, initial_query, client, model, n=1):
+def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: str = None):
"""
Implement the RE2 (Re-Reading) approach for improved reasoning in LLMs.
@@ -28,11 +30,18 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
]
try:
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- n=n
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "n": n
+ }
+ response = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
re2_completion_tokens += response.usage.completion_tokens
if n == 1:
return response.choices[0].message.content.strip(), re2_completion_tokens
diff --git a/optillm/rstar.py b/optillm/rstar.py
index aaca2e92..5cbeda6c 100644
--- a/optillm/rstar.py
+++ b/optillm/rstar.py
@@ -6,6 +6,8 @@
import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor
+import optillm
+from optillm import conversation_logger
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -21,7 +23,7 @@ def __init__(self, state: str, action: str, parent: 'Node' = None):
self.value = 0.0
class RStar:
- def __init__(self, system: str, client, model: str, max_depth: int = 3, num_rollouts: int = 5, c: float = 1.4):
+ def __init__(self, system: str, client, model: str, max_depth: int = 3, num_rollouts: int = 5, c: float = 1.4, request_id: str = None):
self.client = client
self.model_name = model
self.max_depth = max_depth
@@ -31,6 +33,7 @@ def __init__(self, system: str, client, model: str, max_depth: int = 3, num_roll
self.original_question = None
self.system = system
self.rstar_completion_tokens = 0
+ self.request_id = request_id
logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")
async def generate_response_async(self, prompt: str) -> str:
@@ -93,15 +96,22 @@ async def solve_async(self, question: str) -> str:
def generate_response(self, prompt: str) -> str:
logger.debug(f"Generating response for prompt: {prompt[:100]}...")
- response = self.client.chat.completions.create(
- model=self.model_name,
- messages=[
+ provider_request = {
+ "model": self.model_name,
+ "messages": [
{"role": "system", "content": "You are a helpful assistant focused on solving mathematical problems. Stick to the given question and avoid introducing new scenarios."},
{"role": "user", "content": prompt}
],
- max_tokens=4096,
- temperature=0.2
- )
+ "max_tokens": 4096,
+ "temperature": 0.2
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.rstar_completion_tokens += response.usage.completion_tokens
generated_response = response.choices[0].message.content.strip()
logger.debug(f"Generated response: {generated_response}")
diff --git a/optillm/rto.py b/optillm/rto.py
index 61e70e5f..59ca88d6 100644
--- a/optillm/rto.py
+++ b/optillm/rto.py
@@ -1,5 +1,7 @@
import re
import logging
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
@@ -13,45 +15,66 @@ def extract_code_from_prompt(text):
logger.warning("Could not extract code from prompt. Returning original text.")
return text
-def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str) -> str:
+def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
rto_completion_tokens = 0
messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}]
# Generate initial code (C1)
- response_c1 = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=4096,
- n=1,
- temperature=0.1
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 4096,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response_c1 = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response_c1.model_dump() if hasattr(response_c1, 'model_dump') else response_c1
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
c1 = response_c1.choices[0].message.content
rto_completion_tokens += response_c1.usage.completion_tokens
# Generate description of the code (Q2)
messages.append({"role": "assistant", "content": c1})
messages.append({"role": "user", "content": "Summarize or describe the code you just created. The summary should be in form of an instruction such that, given the instruction you can create the code yourself."})
- response_q2 = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=1024,
- n=1,
- temperature=0.1
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 1024,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response_q2 = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response_q2.model_dump() if hasattr(response_q2, 'model_dump') else response_q2
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
q2 = response_q2.choices[0].message.content
rto_completion_tokens += response_q2.usage.completion_tokens
# Generate second code based on the description (C2)
messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": q2}]
- response_c2 = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=4096,
- n=1,
- temperature=0.1
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 4096,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response_c2 = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response_c2.model_dump() if hasattr(response_c2, 'model_dump') else response_c2
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
c2 = response_c2.choices[0].message.content
rto_completion_tokens += response_c2.usage.completion_tokens
@@ -63,13 +86,20 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Initial query: {initial_query}\n\nFirst generated code (C1):\n{c1}\n\nSecond generated code (C2):\n{c2}\n\nBased on the initial query and these two different code implementations, generate a final, optimized version of the code. Only respond with the final code, do not return anything else."}]
- response_c3 = client.chat.completions.create(
- model=model,
- messages=messages,
- max_tokens=4096,
- n=1,
- temperature=0.1
- )
+ provider_request = {
+ "model": model,
+ "messages": messages,
+ "max_tokens": 4096,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response_c3 = client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+ response_dict = response_c3.model_dump() if hasattr(response_c3, 'model_dump') else response_c3
+ optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
c3 = response_c3.choices[0].message.content
rto_completion_tokens += response_c3.usage.completion_tokens
diff --git a/optillm/self_consistency.py b/optillm/self_consistency.py
index 3b33bb73..441599bc 100644
--- a/optillm/self_consistency.py
+++ b/optillm/self_consistency.py
@@ -1,29 +1,39 @@
import logging
from typing import List, Dict
from difflib import SequenceMatcher
+import optillm
+from optillm import conversation_logger
logger = logging.getLogger(__name__)
class AdvancedSelfConsistency:
- def __init__(self, client, model: str, num_samples: int = 5, similarity_threshold: float = 0.8):
+ def __init__(self, client, model: str, num_samples: int = 5, similarity_threshold: float = 0.8, request_id: str = None):
self.client = client
self.model = model
self.num_samples = num_samples
self.similarity_threshold = similarity_threshold
self.self_consistency_completion_tokens = 0
+ self.request_id = request_id
def generate_responses(self, system_prompt: str, user_prompt: str) -> List[str]:
responses = []
for _ in range(self.num_samples):
- response = self.client.chat.completions.create(
- model=self.model,
- messages=[
+ provider_request = {
+ "model": self.model,
+ "messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
- temperature=1,
- max_tokens=4096
- )
+ "temperature": 1,
+ "max_tokens": 4096
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.self_consistency_completion_tokens += response.usage.completion_tokens
responses.append(response.choices[0].message.content)
return responses
@@ -73,8 +83,8 @@ def evaluate(self, system_prompt: str, user_prompt: str) -> Dict[str, any]:
"aggregated_result": aggregated_result
}
-def advanced_self_consistency_approach(system_prompt: str, initial_query: str, client, model: str) -> str:
- self_consistency = AdvancedSelfConsistency(client, model)
+def advanced_self_consistency_approach(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
+ self_consistency = AdvancedSelfConsistency(client, model, request_id=request_id)
result = self_consistency.evaluate(system_prompt, initial_query)
logger.info("Advanced Self-Consistency Results:")
diff --git a/optillm/z3_solver.py b/optillm/z3_solver.py
index 976ec92f..dcc83d1b 100644
--- a/optillm/z3_solver.py
+++ b/optillm/z3_solver.py
@@ -9,6 +9,8 @@
import math
import multiprocessing
import traceback
+import optillm
+from optillm import conversation_logger
class TimeoutException(Exception):
pass
@@ -131,12 +133,13 @@ def Rational(numerator, denominator=1):
return ("success", output_buffer.getvalue())
class Z3SymPySolverSystem:
- def __init__(self, system_prompt: str, client, model: str, timeout: int = 30):
+ def __init__(self, system_prompt: str, client, model: str, timeout: int = 30, request_id: str = None):
self.system_prompt = system_prompt
self.model = model
self.client = client
self.timeout = timeout
self.solver_completion_tokens = 0
+ self.request_id = request_id
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def process_query(self, query: str) -> str:
@@ -177,16 +180,23 @@ def analyze_query(self, query: str) -> str:
[Your step-by-step analysis]
"""
- analysis_response = self.client.chat.completions.create(
- model=self.model,
- messages=[
+ provider_request = {
+ "model": self.model,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": analysis_prompt}
],
- max_tokens=1024,
- n=1,
- temperature=0.1
- )
+ "max_tokens": 1024,
+ "n": 1,
+ "temperature": 0.1
+ }
+ analysis_response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = analysis_response.model_dump() if hasattr(analysis_response, 'model_dump') else analysis_response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.solver_completion_tokens = analysis_response.usage.completion_tokens
return analysis_response.choices[0].message.content
@@ -205,30 +215,44 @@ def generate_response(self, query: str, analysis: str, solver_result: Dict[str,
Response:
"""
- response = self.client.chat.completions.create(
- model=self.model,
- messages=[
+ provider_request = {
+ "model": self.model,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": response_prompt}
],
- max_tokens=4096,
- n=1,
- temperature=0.1
- )
+ "max_tokens": 4096,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.solver_completion_tokens = response.usage.completion_tokens
return response.choices[0].message.content
def standard_llm_inference(self, query: str) -> str:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=[
+ provider_request = {
+ "model": self.model,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": query}
],
- max_tokens=4096,
- n=1,
- temperature=0.1
- )
+ "max_tokens": 4096,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.solver_completion_tokens = response.usage.completion_tokens
return response.choices[0].message.content
@@ -265,16 +289,23 @@ def solve_with_z3_sympy(self, formulation: str, max_attempts: int = 3) -> Dict[s
# Corrected code here
```
"""
- response = self.client.chat.completions.create(
- model=self.model,
- messages=[
+ provider_request = {
+ "model": self.model,
+ "messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": error_prompt}
],
- max_tokens=1024,
- n=1,
- temperature=0.1
- )
+ "max_tokens": 1024,
+ "n": 1,
+ "temperature": 0.1
+ }
+ response = self.client.chat.completions.create(**provider_request)
+
+ # Log provider call
+ if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
+ response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+ optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
+
self.solver_completion_tokens = response.usage.completion_tokens
formulation = self.extract_and_validate_expressions(response.choices[0].message.content)
diff --git a/pyproject.toml b/pyproject.toml
index 395be3b4..4d68a74f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "optillm"
-version = "0.1.29"
+version = "0.2.0"
description = "An optimizing inference proxy for LLMs."
readme = "README.md"
license = "Apache-2.0"
diff --git a/tests/test_conversation_logger.py b/tests/test_conversation_logger.py
new file mode 100644
index 00000000..b3a19b72
--- /dev/null
+++ b/tests/test_conversation_logger.py
@@ -0,0 +1,212 @@
+import json
+import tempfile
+import unittest
+from datetime import datetime, timezone
+from pathlib import Path
+from unittest.mock import patch
+
+import sys
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from optillm.conversation_logger import ConversationLogger, ConversationEntry
+
+
+class TestConversationLogger(unittest.TestCase):
+ def setUp(self):
+ """Set up test fixtures"""
+ self.temp_dir = Path(tempfile.mkdtemp())
+ self.logger_enabled = ConversationLogger(self.temp_dir, enabled=True)
+ self.logger_disabled = ConversationLogger(self.temp_dir, enabled=False)
+
+ def tearDown(self):
+ """Clean up test fixtures"""
+ # Clean up temp directory
+ for file in self.temp_dir.glob("*"):
+ file.unlink()
+ self.temp_dir.rmdir()
+
+ def test_logger_initialization_and_disabled_state(self):
+ """Test logger initialization and disabled logger behavior"""
+ # Test enabled logger
+ self.assertTrue(self.logger_enabled.enabled)
+ self.assertEqual(self.logger_enabled.log_dir, self.temp_dir)
+ self.assertTrue(self.temp_dir.exists())
+
+ # Test disabled logger
+ self.assertFalse(self.logger_disabled.enabled)
+
+ # Disabled logger should return empty string and perform no operations
+ request_id = self.logger_disabled.start_conversation({}, "test", "model")
+ self.assertEqual(request_id, "")
+
+ # Other methods should not raise errors but do nothing
+ self.logger_disabled.log_provider_call("req1", {}, {})
+ self.logger_disabled.log_final_response("req1", {})
+ self.logger_disabled.log_error("req1", "error")
+ self.logger_disabled.finalize_conversation("req1")
+
+ def test_conversation_lifecycle(self):
+ """Test complete conversation lifecycle: start, log calls, errors, finalize"""
+ client_request = {
+ "messages": [{"role": "user", "content": "Hello"}],
+ "model": "gpt-4o-mini",
+ "temperature": 0.7
+ }
+
+ # Start conversation
+ request_id = self.logger_enabled.start_conversation(
+ client_request=client_request,
+ approach="moa",
+ model="gpt-4o-mini"
+ )
+
+ # Should return a valid request ID
+ self.assertIsInstance(request_id, str)
+ self.assertTrue(request_id.startswith("req_"))
+ self.assertEqual(len(request_id), 12) # "req_" + 8 hex chars
+
+ # Should create an active entry
+ self.assertIn(request_id, self.logger_enabled.active_entries)
+ entry = self.logger_enabled.active_entries[request_id]
+ self.assertEqual(entry.request_id, request_id)
+ self.assertEqual(entry.approach, "moa")
+ self.assertEqual(entry.model, "gpt-4o-mini")
+
+ # Log multiple provider calls
+ provider_request = {"model": "test", "messages": []}
+ provider_response = {"choices": [{"message": {"content": "response"}}]}
+
+ self.logger_enabled.log_provider_call(request_id, provider_request, provider_response)
+ self.logger_enabled.log_provider_call(request_id, provider_request, provider_response)
+
+ # Check calls were logged
+ entry = self.logger_enabled.active_entries[request_id]
+ self.assertEqual(len(entry.provider_calls), 2)
+ self.assertEqual(entry.provider_calls[0]["call_number"], 1)
+ self.assertEqual(entry.provider_calls[1]["call_number"], 2)
+
+ # Log final response
+ final_response = {"choices": [{"message": {"content": "final"}}]}
+ self.logger_enabled.log_final_response(request_id, final_response)
+
+ # Log error
+ error_msg = "Test error message"
+ self.logger_enabled.log_error(request_id, error_msg)
+
+ # Check entries were updated
+ entry = self.logger_enabled.active_entries[request_id]
+ self.assertEqual(entry.error, error_msg)
+
+ # Finalize the conversation
+ self.logger_enabled.finalize_conversation(request_id)
+
+ # Should no longer be in active entries
+ self.assertNotIn(request_id, self.logger_enabled.active_entries)
+
+ # Should have written to file
+ log_files = list(self.temp_dir.glob("conversations_*.jsonl"))
+ self.assertEqual(len(log_files), 1)
+
+ # Read and verify log content
+ with open(log_files[0], 'r', encoding='utf-8') as f:
+ log_line = f.read().strip()
+
+ log_entry = json.loads(log_line)
+
+ # Verify structure
+ self.assertEqual(log_entry["request_id"], request_id)
+ self.assertEqual(log_entry["approach"], "moa")
+ self.assertEqual(log_entry["model"], "gpt-4o-mini")
+ self.assertEqual(log_entry["client_request"], client_request)
+ self.assertEqual(len(log_entry["provider_calls"]), 2)
+ self.assertEqual(log_entry["final_response"]["choices"][0]["message"]["content"], "final")
+ self.assertIsInstance(log_entry["total_duration_ms"], int)
+ self.assertEqual(log_entry["error"], error_msg)
+
+ def test_multiple_conversations_and_log_files(self):
+ """Test handling multiple concurrent conversations and log file naming"""
+ with patch('optillm.conversation_logger.datetime') as mock_datetime:
+ # Mock datetime.now to return a specific date
+ mock_datetime.now.return_value = datetime(2025, 1, 27, 12, 0, 0, tzinfo=timezone.utc)
+ mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)
+
+ # Test log file naming
+ log_path = self.logger_enabled._get_log_file_path()
+ expected_filename = "conversations_2025-01-27.jsonl"
+ self.assertEqual(log_path.name, expected_filename)
+ self.assertEqual(log_path.parent, self.temp_dir)
+
+ # Start multiple conversations
+ request_id1 = self.logger_enabled.start_conversation({}, "moa", "model1")
+ request_id2 = self.logger_enabled.start_conversation({}, "none", "model2")
+
+ # Should be different IDs
+ self.assertNotEqual(request_id1, request_id2)
+
+ # Both should be active
+ self.assertIn(request_id1, self.logger_enabled.active_entries)
+ self.assertIn(request_id2, self.logger_enabled.active_entries)
+
+ # Log different data to each
+ self.logger_enabled.log_provider_call(request_id1, {"req": "1"}, {"resp": "1"})
+ self.logger_enabled.log_provider_call(request_id2, {"req": "2"}, {"resp": "2"})
+
+ # Finalize both
+ self.logger_enabled.finalize_conversation(request_id1)
+ self.logger_enabled.finalize_conversation(request_id2)
+
+ # Should have 2 log entries in the file
+ log_files = list(self.temp_dir.glob("conversations_*.jsonl"))
+ self.assertEqual(len(log_files), 1)
+
+ with open(log_files[0], 'r', encoding='utf-8') as f:
+ lines = f.read().strip().split('\n')
+
+ self.assertEqual(len(lines), 2)
+
+ # Verify both entries
+ entry1 = json.loads(lines[0])
+ entry2 = json.loads(lines[1])
+
+ self.assertEqual(entry1["approach"], "moa")
+ self.assertEqual(entry2["approach"], "none")
+
+ def test_invalid_request_id_and_stats(self):
+ """Test handling of invalid request IDs and logger statistics"""
+ # Invalid request IDs should not raise errors but do nothing
+ self.logger_enabled.log_provider_call("invalid_id", {}, {})
+ self.logger_enabled.log_final_response("invalid_id", {})
+ self.logger_enabled.log_error("invalid_id", "error")
+ self.logger_enabled.finalize_conversation("invalid_id")
+
+ # Test disabled logger stats
+ stats = self.logger_disabled.get_stats()
+ expected_disabled_stats = {
+ "enabled": False,
+ "log_dir": str(self.temp_dir),
+ "active_conversations": 0
+ }
+ self.assertEqual(stats, expected_disabled_stats)
+
+ # Test enabled logger stats with active conversations
+ request_id1 = self.logger_enabled.start_conversation({}, "test", "model")
+ request_id2 = self.logger_enabled.start_conversation({}, "test", "model")
+
+ stats = self.logger_enabled.get_stats()
+
+ self.assertTrue(stats["enabled"])
+ self.assertEqual(stats["log_dir"], str(self.temp_dir))
+ self.assertEqual(stats["active_conversations"], 2)
+ self.assertEqual(stats["log_files_count"], 0) # No finalized conversations yet
+ self.assertEqual(stats["total_entries_approximate"], 0)
+
+ # Finalize one and check stats again
+ self.logger_enabled.finalize_conversation(request_id1)
+ stats = self.logger_enabled.get_stats()
+
+ self.assertEqual(stats["active_conversations"], 1)
+ self.assertEqual(stats["log_files_count"], 1)
+ self.assertEqual(stats["total_entries_approximate"], 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/test_conversation_logging_approaches.py b/tests/test_conversation_logging_approaches.py
new file mode 100644
index 00000000..43273346
--- /dev/null
+++ b/tests/test_conversation_logging_approaches.py
@@ -0,0 +1,428 @@
+#!/usr/bin/env python3
+"""
+Comprehensive tests for conversation logging across all approaches
+Tests ensure that all approaches properly log API calls without regressions
+"""
+
+import unittest
+import sys
+import os
+import json
+from unittest.mock import Mock, MagicMock, patch, call
+import tempfile
+from pathlib import Path
+
+# Add parent directory to path for imports
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import optillm
+from optillm.conversation_logger import ConversationLogger
+
+# Import all approaches we've modified
+from optillm.bon import best_of_n_sampling
+from optillm.mcts import chat_with_mcts
+from optillm.rto import round_trip_optimization
+from optillm.pvg import inference_time_pv_game
+from optillm.cot_reflection import cot_reflection
+from optillm.self_consistency import advanced_self_consistency_approach
+from optillm.reread import re2_approach
+from optillm.rstar import RStar
+from optillm.z3_solver import Z3SymPySolverSystem
+
+
+class MockOpenAIResponse:
+ """Mock OpenAI API response"""
+ def __init__(self, content="Test response", usage_tokens=10, n=1, call_index=0):
+ self.choices = []
+ for i in range(n):
+ choice = Mock()
+ choice.message = Mock()
+ # Make different content for different calls to avoid early returns
+ if call_index % 2 == 0:
+ choice.message.content = f"Code version A: {content} {i+1}" if n > 1 else f"Code version A: {content}"
+ else:
+ choice.message.content = f"Code version B: {content} {i+1}" if n > 1 else f"Code version B: {content}"
+ self.choices.append(choice)
+
+ self.usage = Mock()
+ self.usage.completion_tokens = usage_tokens
+ self.usage.completion_tokens_details = Mock()
+ self.usage.completion_tokens_details.reasoning_tokens = 0
+
+ def model_dump(self):
+ return {
+ "choices": [{"message": {"content": choice.message.content}} for choice in self.choices],
+ "usage": {"completion_tokens": self.usage.completion_tokens}
+ }
+
+
+class MockOpenAIClient:
+ """Mock OpenAI client for testing"""
+ def __init__(self, response_content="Test response", usage_tokens=10, n_responses=1):
+ self.chat = Mock()
+ self.chat.completions = Mock()
+ self.responses = []
+
+ # Create multiple responses if needed
+ for i in range(20): # Create enough responses for complex approaches
+ response = MockOpenAIResponse(response_content, usage_tokens, n_responses, i)
+ self.responses.append(response)
+
+ self.call_count = 0
+ self.chat.completions.create = self._create_response
+
+ def _create_response(self, **kwargs):
+ """Return the next response in sequence"""
+ response = self.responses[self.call_count % len(self.responses)]
+ self.call_count += 1
+ # Handle n parameter for BON approach
+ n = kwargs.get('n', 1)
+ if n > 1:
+ # Return a response with multiple choices for BON
+ return MockOpenAIResponse("Different response", 10, n, self.call_count)
+ return response
+
+
+class TestConversationLoggingApproaches(unittest.TestCase):
+ """Test conversation logging across all approaches"""
+
+ def setUp(self):
+ """Set up test environment"""
+ self.temp_dir = tempfile.mkdtemp()
+ self.log_dir = Path(self.temp_dir) / "conversations"
+ self.logger = ConversationLogger(self.log_dir, enabled=True)
+
+ # Mock optillm.conversation_logger
+ optillm.conversation_logger = self.logger
+
+ # Common test parameters
+ self.system_prompt = "You are a helpful assistant."
+ self.initial_query = "What is 2 + 2?"
+ self.model = "test-model"
+ self.request_id = "test-request-123"
+
+ # Create mock client
+ self.client = MockOpenAIClient()
+
+ def tearDown(self):
+ """Clean up test environment"""
+ import shutil
+ shutil.rmtree(self.temp_dir, ignore_errors=True)
+ optillm.conversation_logger = None
+
+ def test_multi_call_approaches_logging(self):
+ """Test BON, MCTS, and RTO approaches log API calls correctly"""
+ # Test BON approach
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "bon",
+ self.model
+ )
+
+ result, tokens = best_of_n_sampling(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ n=2,
+ request_id=self.request_id
+ )
+
+ # BON makes multiple calls for sampling and rating
+ bon_calls = self.client.call_count
+ self.assertGreaterEqual(bon_calls, 2)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Reset client and test MCTS
+ self.client.call_count = 0
+ mcts_request_id = self.request_id + "_mcts"
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "mcts",
+ self.model
+ )
+
+ result, tokens = chat_with_mcts(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ num_simulations=2,
+ exploration_weight=0.2,
+ simulation_depth=1,
+ request_id=mcts_request_id
+ )
+
+ mcts_calls = self.client.call_count
+ self.assertGreaterEqual(mcts_calls, 1)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Reset client and test RTO
+ self.client.call_count = 0
+ rto_request_id = self.request_id + "_rto"
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "rto",
+ self.model
+ )
+
+ result, tokens = round_trip_optimization(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ request_id=rto_request_id
+ )
+
+ # RTO makes either 3 calls (if C1==C2) or 4 calls (C1 -> Q2 -> C2 -> C3)
+ rto_calls = self.client.call_count
+ self.assertGreaterEqual(rto_calls, 3)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ def test_single_call_approaches_logging(self):
+ """Test CoT Reflection and RE2 approaches log single API calls correctly"""
+ # Test CoT Reflection
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "cot_reflection",
+ self.model
+ )
+
+ result, tokens = cot_reflection(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ request_id=self.request_id
+ )
+
+ # CoT Reflection makes exactly 1 API call
+ self.assertEqual(self.client.call_count, 1)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Reset client and test RE2
+ self.client.call_count = 0
+ re2_request_id = self.request_id + "_re2"
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "re2",
+ self.model
+ )
+
+ result, tokens = re2_approach(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ n=1,
+ request_id=re2_request_id
+ )
+
+ # RE2 makes exactly 1 API call
+ self.assertEqual(self.client.call_count, 1)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ def test_sampling_approaches_logging(self):
+ """Test PVG and Self Consistency approaches log multiple sampling calls"""
+ # Test PVG approach
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "pvg",
+ self.model
+ )
+
+ result, tokens = inference_time_pv_game(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ num_rounds=1, # Reduce rounds for faster testing
+ num_solutions=2, # Reduce solutions for faster testing
+ request_id=self.request_id
+ )
+
+ # PVG makes multiple API calls: solutions + verifications + refinement
+ pvg_calls = self.client.call_count
+ self.assertGreaterEqual(pvg_calls, 3)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Reset client and test Self Consistency
+ self.client.call_count = 0
+ sc_request_id = self.request_id + "_sc"
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "self_consistency",
+ self.model
+ )
+
+ result, tokens = advanced_self_consistency_approach(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ request_id=sc_request_id
+ )
+
+ # Self Consistency makes num_samples API calls (default 5)
+ self.assertEqual(self.client.call_count, 5)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ @patch('optillm.z3_solver.multiprocessing.get_context')
+ def test_complex_class_based_approaches_logging(self, mock_mp_context):
+ """Test RStar and Z3 Solver class-based approaches log API calls correctly"""
+ # Test RStar approach
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "rstar",
+ self.model
+ )
+
+ rstar = RStar(
+ self.system_prompt,
+ self.client,
+ self.model,
+ max_depth=2, # Reduce depth for faster testing
+ num_rollouts=2, # Reduce rollouts for faster testing
+ request_id=self.request_id
+ )
+
+ result, tokens = rstar.solve(self.initial_query)
+
+ # RStar makes multiple API calls during MCTS rollouts
+ rstar_calls = self.client.call_count
+ self.assertGreaterEqual(rstar_calls, 1)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Reset client and test Z3 Solver
+ self.client.call_count = 0
+ z3_request_id = self.request_id + "_z3"
+ self.logger.start_conversation(
+ {"model": self.model, "messages": []},
+ "z3",
+ self.model
+ )
+
+ # Mock multiprocessing for z3_solver
+ mock_pool = Mock()
+ mock_result = Mock()
+ mock_result.get.return_value = ("success", "Test solver output")
+ mock_pool.apply_async.return_value = mock_result
+ mock_context = Mock()
+ mock_context.Pool.return_value = MagicMock() # Use MagicMock for context manager
+ mock_context.Pool.return_value.__enter__.return_value = mock_pool
+ mock_context.Pool.return_value.__exit__.return_value = None
+ mock_mp_context.return_value = mock_context
+
+ z3_solver = Z3SymPySolverSystem(
+ self.system_prompt,
+ self.client,
+ self.model,
+ request_id=z3_request_id
+ )
+
+ result, tokens = z3_solver.process_query(self.initial_query)
+
+ # Z3 Solver makes at least 1 API call for analysis
+ self.assertGreaterEqual(self.client.call_count, 1)
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ def test_logging_edge_cases(self):
+ """Test approaches work with logging disabled, no request_id, and API errors"""
+ # Test with logging disabled
+ optillm.conversation_logger = None
+
+ result, tokens = best_of_n_sampling(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ n=2,
+ request_id=self.request_id
+ )
+
+ # Should still work normally
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Re-enable logging for next test
+ optillm.conversation_logger = self.logger
+
+ # Test with no request_id
+ self.client.call_count = 0
+ result, tokens = cot_reflection(
+ self.system_prompt,
+ self.initial_query,
+ self.client,
+ self.model,
+ request_id=None
+ )
+
+ # Should still work normally
+ self.assertIsInstance(result, str)
+ self.assertGreater(tokens, 0)
+
+ # Test API error handling
+ error_client = Mock()
+ error_client.chat.completions.create.side_effect = Exception("API Error")
+
+ # Test that approaches handle errors gracefully
+ with self.assertRaises(Exception):
+ cot_reflection(
+ self.system_prompt,
+ self.initial_query,
+ error_client,
+ self.model,
+ request_id=self.request_id
+ )
+
+
+ def test_full_integration_with_file_logging(self):
+ """Test complete integration from approach execution to file logging"""
+ # Start conversation and get request_id
+ request_id = self.logger.start_conversation(
+ {"model": "test-model", "messages": []},
+ "bon",
+ "test-model"
+ )
+
+ # Run approach with the returned request_id
+ result, tokens = best_of_n_sampling(
+ "You are a helpful assistant.",
+ "What is 2 + 2?",
+ self.client,
+ "test-model",
+ n=2,
+ request_id=request_id
+ )
+
+ # Finalize conversation
+ self.logger.finalize_conversation(request_id)
+
+ # Check that conversation was logged
+ log_files = list(self.log_dir.glob("*.jsonl"))
+ self.assertGreater(len(log_files), 0)
+
+ # Check that log file contains entries
+ with open(log_files[0], 'r') as f:
+ lines = f.readlines()
+ self.assertGreater(len(lines), 0)
+
+ # Verify log entry structure
+ log_entry = json.loads(lines[0].strip())
+ self.assertEqual(log_entry["approach"], "bon")
+ self.assertIn("provider_calls", log_entry)
+ self.assertGreater(len(log_entry["provider_calls"]), 0)
+
+
+if __name__ == "__main__":
+ # Configure test runner
+ unittest.main(verbosity=2, buffer=True)
\ No newline at end of file
diff --git a/tests/test_conversation_logging_server.py b/tests/test_conversation_logging_server.py
new file mode 100644
index 00000000..68a09c7c
--- /dev/null
+++ b/tests/test_conversation_logging_server.py
@@ -0,0 +1,572 @@
+#!/usr/bin/env python3
+"""
+Server-based integration tests for conversation logging with real model
+Tests conversation logging with actual OptILLM server and google/gemma-3-270m-it model
+"""
+
+import unittest
+import sys
+import os
+import requests
+import json
+import tempfile
+import time
+import subprocess
+from pathlib import Path
+from openai import OpenAI
+
+# Add parent directory to path for imports
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from test_utils import TEST_MODEL, setup_test_env, start_test_server, stop_test_server
+
+
+class TestConversationLoggingWithServer(unittest.TestCase):
+ """Integration tests with real OptILLM server and conversation logging"""
+
+ @classmethod
+ def setUpClass(cls):
+ """Set up OptILLM server for testing"""
+ setup_test_env()
+
+ # Check if server is already running
+ cls.server_available = cls._check_existing_server()
+ cls.server_process = None
+ cls.temp_log_dir = None
+
+ if not cls.server_available:
+ # Start our own server with logging enabled
+ cls.temp_log_dir = Path(tempfile.mkdtemp())
+ cls.server_process = cls._start_server_with_logging()
+
+ # Wait for server to be ready
+ max_wait = 30 # seconds
+ start_time = time.time()
+
+ while time.time() - start_time < max_wait:
+ if cls._check_server_health():
+ cls.server_available = True
+ break
+ time.sleep(1)
+
+ if not cls.server_available:
+ if cls.server_process:
+ stop_test_server(cls.server_process)
+ raise unittest.SkipTest("Could not start OptILLM server for testing")
+
+ @classmethod
+ def tearDownClass(cls):
+ """Clean up server"""
+ if cls.server_process:
+ stop_test_server(cls.server_process)
+
+ if cls.temp_log_dir and cls.temp_log_dir.exists():
+ import shutil
+ shutil.rmtree(cls.temp_log_dir, ignore_errors=True)
+
+ @staticmethod
+ def _check_existing_server():
+ """Check if OptILLM server is already running"""
+ try:
+ response = requests.get("http://localhost:8000/v1/health", timeout=2)
+ return response.status_code == 200
+ except requests.exceptions.RequestException:
+ return False
+
+ @staticmethod
+ def _check_server_health():
+ """Check if server is healthy"""
+ try:
+ response = requests.get("http://localhost:8000/v1/health", timeout=5)
+ return response.status_code == 200
+ except requests.exceptions.RequestException:
+ return False
+
+ @classmethod
+ def _start_server_with_logging(cls):
+ """Start server with conversation logging enabled"""
+ env = os.environ.copy()
+ env["OPTILLM_API_KEY"] = "optillm"
+ env["OPTILLM_LOG_CONVERSATIONS"] = "true"
+ env["OPTILLM_CONVERSATION_LOG_DIR"] = str(cls.temp_log_dir)
+
+ proc = subprocess.Popen([
+ sys.executable, "optillm.py",
+ "--model", TEST_MODEL,
+ "--port", "8000",
+ "--log-conversations",
+ "--conversation-log-dir", str(cls.temp_log_dir)
+ ], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ return proc
+
+ def setUp(self):
+ """Set up test client"""
+ if not self.server_available:
+ self.skipTest("OptILLM server not available")
+
+ self.client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")
+
+ # Determine log directory - use temp dir if we started server, otherwise default
+ if self.temp_log_dir:
+ self.log_dir = self.temp_log_dir
+ else:
+ self.log_dir = Path.home() / ".optillm" / "conversations"
+
+ # Record initial state for comparison
+ self.initial_log_files = set(self.log_dir.glob("*.jsonl")) if self.log_dir.exists() else set()
+
+ def _get_new_log_entries(self):
+ """Get new log entries since test started"""
+ if not self.log_dir.exists():
+ return []
+
+ current_log_files = set(self.log_dir.glob("*.jsonl"))
+ new_files = current_log_files - self.initial_log_files
+ modified_files = [f for f in self.initial_log_files if f in current_log_files and f.stat().st_mtime > time.time() - 60]
+
+ entries = []
+ for log_file in new_files.union(set(modified_files)):
+ try:
+ with open(log_file, 'r') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ entries.append(json.loads(line))
+ except (json.JSONDecodeError, IOError):
+ continue
+
+ return entries
+
+ def test_basic_none_approach_logging(self):
+ """Test basic none approach with conversation logging"""
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is 2 + 2? Answer with just the number."}
+ ],
+ max_tokens=10
+ )
+
+ # Verify response
+ self.assertIsNotNone(response)
+ self.assertGreater(len(response.choices), 0)
+ self.assertIsNotNone(response.choices[0].message.content)
+
+ # Wait for logging
+ time.sleep(2)
+
+ # Check for new log entries
+ entries = self._get_new_log_entries()
+ self.assertGreater(len(entries), 0, "No log entries found for basic none approach")
+
+ # Verify at least one entry has the expected structure
+ found_entry = False
+ for entry in entries:
+ if entry.get("approach") == "none" and entry.get("model") == TEST_MODEL:
+ found_entry = True
+ self.assertIn("provider_calls", entry)
+ self.assertIn("client_request", entry)
+ self.assertIn("timestamp", entry)
+ break
+
+ self.assertTrue(found_entry, "No valid log entry found for none approach")
+
+ def test_re2_approach_logging(self):
+ """Test RE2 approach with conversation logging"""
+ response = self.client.chat.completions.create(
+ model=f"re2-{TEST_MODEL}",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France? Answer in one word."}
+ ],
+ max_tokens=10
+ )
+
+ # Verify response
+ self.assertIsNotNone(response)
+ self.assertGreater(len(response.choices), 0)
+
+ # Wait for logging
+ time.sleep(3)
+
+ # Check for new log entries
+ entries = self._get_new_log_entries()
+
+ # Find RE2 entry
+ re2_entry = None
+ for entry in entries:
+ if entry.get("approach") == "re2":
+ re2_entry = entry
+ break
+
+ self.assertIsNotNone(re2_entry, "No RE2 log entry found")
+ self.assertEqual(re2_entry["model"], TEST_MODEL)
+ self.assertIn("provider_calls", re2_entry)
+ self.assertGreaterEqual(len(re2_entry["provider_calls"]), 1)
+
+ def test_cot_reflection_approach_logging(self):
+ """Test CoT Reflection approach with conversation logging"""
+ response = self.client.chat.completions.create(
+ model=f"cot_reflection-{TEST_MODEL}",
+ messages=[
+ {"role": "system", "content": "Think step by step."},
+ {"role": "user", "content": "What is 3 Ć 4? Show your work."}
+ ],
+ max_tokens=50
+ )
+
+ # Verify response
+ self.assertIsNotNone(response)
+ self.assertGreater(len(response.choices), 0)
+
+ # Wait for logging
+ time.sleep(3)
+
+ # Check for log entries
+ entries = self._get_new_log_entries()
+
+ # Find CoT reflection entry
+ cot_entry = None
+ for entry in entries:
+ if entry.get("approach") == "cot_reflection":
+ cot_entry = entry
+ break
+
+ self.assertIsNotNone(cot_entry, "No CoT reflection log entry found")
+ self.assertEqual(cot_entry["model"], TEST_MODEL)
+ self.assertIn("provider_calls", cot_entry)
+ self.assertGreaterEqual(len(cot_entry["provider_calls"]), 1)
+
+ def test_extra_body_approach_logging(self):
+ """Test approach specification via extra_body parameter"""
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Test extra_body. Reply with 'OK'."}
+ ],
+ extra_body={"optillm_approach": "re2"},
+ max_tokens=10
+ )
+
+ # Verify response
+ self.assertIsNotNone(response)
+ self.assertGreater(len(response.choices), 0)
+
+ # Wait for logging
+ time.sleep(3)
+
+ # Check for log entries
+ entries = self._get_new_log_entries()
+
+ # Find entry with RE2 approach (specified via extra_body)
+ found_entry = False
+ for entry in entries:
+ if entry.get("approach") == "re2" and entry.get("model") == TEST_MODEL:
+ found_entry = True
+ self.assertIn("provider_calls", entry)
+ break
+
+ self.assertTrue(found_entry, "No log entry found for extra_body approach specification")
+
+ def test_reasoning_tokens_logging(self):
+ """Test that reasoning tokens are properly logged"""
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "Think step by step and show reasoning."},
+ {"role": "user", "content": "What is 5 + 7? Explain your thinking."}
+ ],
+ max_tokens=100
+ )
+
+ # Verify response structure
+ self.assertIsNotNone(response)
+ self.assertIsNotNone(response.usage)
+
+ # Wait for logging
+ time.sleep(3)
+
+ # Check for log entries with usage information
+ entries = self._get_new_log_entries()
+
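+ # The usage block checked below is expected to follow the OpenAI-style shape,
+ # roughly (completion_tokens_details / reasoning_tokens may be absent for
+ # models that do not emit reasoning traces):
+ #   {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...,
+ #    "completion_tokens_details": {"reasoning_tokens": ...}}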
+ found_usage_entry = False
+ for entry in entries:
+ if "provider_calls" in entry and len(entry["provider_calls"]) > 0:
+ for call in entry["provider_calls"]:
+ if "response" in call and "usage" in call["response"]:
+ found_usage_entry = True
+ usage = call["response"]["usage"]
+ self.assertIn("completion_tokens", usage)
+ # reasoning_tokens might be 0 or missing for this simple model
+ if "completion_tokens_details" in usage:
+ details = usage["completion_tokens_details"]
+ if "reasoning_tokens" in details:
+ self.assertIsInstance(details["reasoning_tokens"], int)
+ break
+ if found_usage_entry:
+ break
+
+ self.assertTrue(found_usage_entry, "No log entry with usage information found")
+
+ def test_multiple_approaches_logging(self):
+ """Test multiple different approaches get logged correctly"""
+ approaches_to_test = [
+ ("none", TEST_MODEL),
+ ("re2", f"re2-{TEST_MODEL}"),
+ ("cot_reflection", f"cot_reflection-{TEST_MODEL}")
+ ]
+
+ responses = []
+ for approach_name, model_name in approaches_to_test:
+ try:
+ response = self.client.chat.completions.create(
+ model=model_name,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": f"Test {approach_name}. Reply 'OK'."}
+ ],
+ max_tokens=10
+ )
+ responses.append((approach_name, response))
+ time.sleep(1) # Brief pause between requests
+ except Exception as e:
+ self.fail(f"Approach {approach_name} failed: {e}")
+
+ # Verify all responses
+ self.assertEqual(len(responses), 3)
+ for approach_name, response in responses:
+ self.assertIsNotNone(response)
+ self.assertGreater(len(response.choices), 0)
+
+ # Wait for logging
+ time.sleep(5)
+
+ # Check for log entries
+ entries = self._get_new_log_entries()
+
+ # Find entries for each approach
+ found_approaches = set()
+ for entry in entries:
+ approach = entry.get("approach")
+ if approach in ["none", "re2", "cot_reflection"]:
+ found_approaches.add(approach)
+ self.assertEqual(entry["model"], TEST_MODEL)
+ self.assertIn("provider_calls", entry)
+
+ # At least 2 of the 3 approaches should have been logged
+ self.assertGreaterEqual(len(found_approaches), 2, f"Too few approaches logged. Found: {found_approaches}")
+
+ def test_concurrent_requests_logging(self):
+ """Test that concurrent requests are logged properly"""
+ import threading
+ import queue
+
+ results = queue.Queue()
+
+ def make_request(index):
+ try:
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": f"Concurrent test {index}. Reply with the number {index}."}
+ ],
+ max_tokens=10
+ )
+ results.put(("success", index, response))
+ except Exception as e:
+ results.put(("error", index, str(e)))
+
+ # Start multiple concurrent requests
+ threads = []
+ for i in range(3):
+ thread = threading.Thread(target=make_request, args=(i,))
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join(timeout=30)
+
+ # Collect results
+ successful_requests = []
+ while not results.empty():
+ result_type, index, result = results.get()
+ if result_type == "success":
+ successful_requests.append((index, result))
+ else:
+ self.fail(f"Concurrent request {index} failed: {result}")
+
+ self.assertGreaterEqual(len(successful_requests), 2, "Not enough concurrent requests succeeded")
+
+ # Wait for logging
+ time.sleep(5)
+
+ # Check for log entries
+ entries = self._get_new_log_entries()
+
+ # Should have entries for concurrent requests
+ concurrent_entries = [e for e in entries if "Concurrent test" in str(e.get("client_request", {}))]
+ self.assertGreaterEqual(len(concurrent_entries), 2, "Not enough concurrent request log entries found")
+
+ def test_error_handling_logging(self):
+ """Test that errors in approaches are properly logged"""
+ # Make request that might cause issues (very low max_tokens)
+ try:
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "This is a test for error logging scenarios."}
+ ],
+ max_tokens=1 # Very low to potentially cause issues
+ )
+ # Even if it succeeds, that's fine for this test
+ self.assertIsNotNone(response)
+ except Exception:
+ # Exception is also fine for this test
+ pass
+
+ # Wait for logging
+ time.sleep(3)
+
+ # Check that some logging occurred (success or error)
+ entries = self._get_new_log_entries()
+
+ # Should have at least some entry (success or partial)
+ found_relevant_entry = False
+ for entry in entries:
+ if "error logging scenarios" in str(entry.get("client_request", {})):
+ found_relevant_entry = True
+ break
+
+ # No hard requirement here: the request may legitimately be rejected at
+ # max_tokens=1, so this test mainly verifies that the request path and log
+ # reading above did not raise. found_relevant_entry is informational only.
+ self.assertIsInstance(entries, list, "Log reading did not return a list of entries")
+
+ def test_log_file_structure_and_format(self):
+ """Test that log files have correct JSONL structure and required fields"""
+ # Make a simple request
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Structure test. Reply 'STRUCTURE_OK'."}
+ ],
+ max_tokens=15
+ )
+
+ self.assertIsNotNone(response)
+
+ # Wait for logging
+ time.sleep(3)
+
+ # Check for log entries
+ entries = self._get_new_log_entries()
+
+ # Find relevant entry
+ relevant_entry = None
+ for entry in entries:
+ if "STRUCTURE_OK" in str(entry.get("client_request", {})) or "Structure test" in str(entry.get("client_request", {})):
+ relevant_entry = entry
+ break
+
+ if not relevant_entry and entries:
+ # Use any recent entry for structure validation
+ relevant_entry = entries[0]
+
+ self.assertIsNotNone(relevant_entry, "No log entry found for structure validation")
+
+ # Verify required fields in consolidated format
+ required_fields = [
+ "timestamp", "request_id", "approach", "model",
+ "client_request", "provider_calls"
+ ]
+
+ for field in required_fields:
+ self.assertIn(field, relevant_entry, f"Missing required field: {field}")
+
+ # Verify provider calls structure
+ provider_calls = relevant_entry["provider_calls"]
+ self.assertIsInstance(provider_calls, list)
+ self.assertGreater(len(provider_calls), 0, "No provider calls logged")
+
+ for call in provider_calls:
+ self.assertIn("request", call)
+ self.assertIn("response", call)
+ self.assertIn("timestamp", call)
+ self.assertIn("call_number", call)
+
+ # Verify timestamps are valid
+ self.assertIsInstance(relevant_entry["timestamp"], str)
+ for call in provider_calls:
+ self.assertIsInstance(call["timestamp"], str)
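+
+ # If stricter validation is wanted, the timestamp strings could also be parsed;
+ # a minimal sketch, assuming the logger writes ISO-8601 timestamps (adjust if a
+ # different format is used):
+ #   from datetime import datetime
+ #   datetime.fromisoformat(relevant_entry["timestamp"].replace("Z", "+00:00"))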
+
+
+@unittest.skipUnless(
+ os.getenv("OPTILLM_API_KEY") == "optillm",
+ "Set OPTILLM_API_KEY=optillm to run server-based tests"
+)
+class TestConversationLoggingPerformanceWithServer(unittest.TestCase):
+ """Performance tests with real server"""
+
+ def setUp(self):
+ """Check server availability"""
+ try:
+ status_code = requests.get("http://localhost:8000/v1/health", timeout=2).status_code
+ except requests.exceptions.RequestException:
+ self.skipTest("OptILLM server not available")
+ if status_code != 200:
+ self.skipTest("OptILLM server not available")
+
+ self.client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")
+
+ def test_logging_performance_impact(self):
+ """Test that logging doesn't significantly impact response time"""
+
+ # Warm up
+ self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[{"role": "user", "content": "warmup"}],
+ max_tokens=5
+ )
+
+ # Time multiple requests
+ times = []
+ for i in range(5):
+ start_time = time.perf_counter()
+ response = self.client.chat.completions.create(
+ model=TEST_MODEL,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": f"Performance test {i}. Reply 'OK'."}
+ ],
+ max_tokens=5
+ )
+ end_time = time.perf_counter()
+
+ # Verify response
+ self.assertIsNotNone(response)
+
+ times.append(end_time - start_time)
+
+ # Calculate average time
+ avg_time = sum(times) / len(times)
+
+ # Should be reasonably fast (under 10 seconds for small model)
+ self.assertLess(avg_time, 10.0, f"Average response time too slow: {avg_time:.2f}s")
+
+ print(f"\nš Server Performance with Logging:")
+ print(f" Average response time: {avg_time:.3f}s")
+ print(f" Response times: {[f'{t:.3f}s' for t in times]}")
+
+
+if __name__ == "__main__":
+ print("š Running conversation logging server-based integration tests...")
+ print("=" * 70)
+ print("These tests require an actual OptILLM server with logging enabled.")
+ print("Set OPTILLM_API_KEY=optillm to run all tests.")
+ print()
+
+ # Run tests
+ unittest.main(verbosity=2, buffer=True)