Merged
Changes from 3 commits
4 changes: 2 additions & 2 deletions README.md
@@ -287,7 +287,7 @@ and then use the same in your OpenAI client. You can pass any HuggingFace model
with your HuggingFace key. We also support adding any number of LoRAs on top of the model by using the `+` separator.

E.g. The following code loads the base model `meta-llama/Llama-3.2-1B-Instruct` and then adds two LoRAs on top - `patched-codes/Llama-3.2-1B-FixVulns` and `patched-codes/Llama-3.2-1B-FastApply`.
You can specify which LoRA to use using the `active_adapter` param in `extra_args` field of OpenAI SDK client. By default we will load the last specified adapter.
You can specify which LoRA to use using the `active_adapter` param in `extra_body` field of OpenAI SDK client. By default we will load the last specified adapter.

```python
OPENAI_BASE_URL = "http://localhost:8000/v1"
@@ -748,4 +748,4 @@ If you use this library in your research, please cite:

<p align="center">
⭐ <a href="https://github.com/codelion/optillm">Star us on GitHub</a> if you find OptiLLM useful!
</p>
</p>
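
For context, a minimal client-side sketch of the usage this README change describes — selecting the active LoRA via `extra_body` — might look like the following. The base URL, API key handling, and choice of adapter are illustrative; the model and adapter names come from the README prose above, and the full original example is truncated in this diff.

```python
from openai import OpenAI

# Assumes a locally running optillm proxy; key handling is illustrative.
client = OpenAI(api_key="dummy-key", base_url="http://localhost:8000/v1")

response = client.chat.completions.create(
    # Base model plus two LoRAs, joined with the '+' separator described above.
    model="meta-llama/Llama-3.2-1B-Instruct"
          "+patched-codes/Llama-3.2-1B-FixVulns"
          "+patched-codes/Llama-3.2-1B-FastApply",
    messages=[{"role": "user", "content": "Hello"}],
    # Select which LoRA to apply; without this, the last-specified adapter is used.
    extra_body={"active_adapter": "patched-codes/Llama-3.2-1B-FastApply"},
)
print(response.choices[0].message.content)
```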
109 changes: 85 additions & 24 deletions optillm.py
@@ -3,6 +3,7 @@
import os
import secrets
import time
from pathlib import Path
from flask import Flask, request, jsonify
from cerebras.cloud.sdk import Cerebras
from openai import AzureOpenAI, OpenAI
@@ -32,6 +33,7 @@
from optillm.reread import re2_approach
from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
from optillm.batching import RequestBatcher, BatchingError
from optillm.conversation_logger import ConversationLogger

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -51,6 +53,9 @@
# Global request batcher (initialized in main() if batch mode enabled)
request_batcher = None

# Global conversation logger (initialized in main() if logging enabled)
conversation_logger = None

def get_config():
API_KEY = None
if os.environ.get("OPTILLM_API_KEY"):
@@ -196,17 +201,17 @@ def none_approach(
client: Any,
model: str,
original_messages: List[Dict[str, str]],
request_id: str = None,
**kwargs
) -> Dict[str, Any]:
"""
Direct proxy approach that passes through all parameters to the underlying endpoint.

Args:
system_prompt: System prompt text (unused)
initial_query: Initial query/conversation (unused)
client: OpenAI client instance
model: Model identifier
original_messages: Original messages from the request
request_id: Optional request ID for conversation logging
**kwargs: Additional parameters to pass through

Returns:
@@ -220,6 +225,13 @@ def none_approach(
# Normalize message content to ensure it's always string
normalized_messages = normalize_message_content(original_messages)

# Prepare request data for logging
provider_request = {
"model": model,
"messages": normalized_messages,
**kwargs
}

# Make the direct completion call with normalized messages and parameters
response = client.chat.completions.create(
model=model,
@@ -228,11 +240,18 @@ def none_approach(
)

# Convert to dict if it's not already
if hasattr(response, 'model_dump'):
return response.model_dump()
return response
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response

# Log the provider call if conversation logging is enabled
if conversation_logger and request_id:
conversation_logger.log_provider_call(request_id, provider_request, response_dict)

return response_dict

except Exception as e:
# Log error if conversation logging is enabled
if conversation_logger and request_id:
conversation_logger.log_error(request_id, f"Error in none approach: {str(e)}")
logger.error(f"Error in none approach: {str(e)}")
raise

@@ -345,7 +364,7 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache

return operation, approaches, actual_model

def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None):
def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None, request_id: str = None):
if approach in known_approaches:
if approach == 'none':
# Extract kwargs from the request data
@@ -356,41 +375,41 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
# Copy all parameters except 'stream', 'model' and 'messages'
kwargs = {k: v for k, v in data.items()
if k not in ['model', 'messages', 'stream', 'optillm_approach']}
response = none_approach(original_messages=messages, client=client, model=model, **kwargs)
response = none_approach(original_messages=messages, client=client, model=model, request_id=request_id, **kwargs)
# For none approach, we return the response and a token count of 0
# since the full token count is already in the response
return response, 0
elif approach == 'mcts':
return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
server_config['mcts_exploration'], server_config['mcts_depth'])
server_config['mcts_exploration'], server_config['mcts_depth'], request_id)
elif approach == 'bon':
return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_id)
elif approach == 'moa':
return mixture_of_agents(system_prompt, initial_query, client, model)
return mixture_of_agents(system_prompt, initial_query, client, model, request_id)
elif approach == 'rto':
return round_trip_optimization(system_prompt, initial_query, client, model)
return round_trip_optimization(system_prompt, initial_query, client, model, request_id)
elif approach == 'z3':
z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_id=request_id)
return z3_solver.process_query(initial_query)
elif approach == "self_consistency":
return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
return advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_id)
elif approach == "pvg":
return inference_time_pv_game(system_prompt, initial_query, client, model)
return inference_time_pv_game(system_prompt, initial_query, client, model, request_id)
elif approach == "rstar":
rstar = RStar(system_prompt, client, model,
max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
c=server_config['rstar_c'])
c=server_config['rstar_c'], request_id=request_id)
return rstar.solve(initial_query)
elif approach == "cot_reflection":
return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config)
return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config, request_id=request_id)
elif approach == 'plansearch':
return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
elif approach == 'leap':
return leap(system_prompt, initial_query, client, model)
return leap(system_prompt, initial_query, client, model, request_id)
elif approach == 're2':
return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
elif approach == 'cepo':
return cepo(system_prompt, initial_query, client, model, cepo_config)
return cepo(system_prompt, initial_query, client, model, cepo_config, request_id)
elif approach in plugin_approaches:
# Check if the plugin accepts request_config
plugin_func = plugin_approaches[approach]
@@ -445,7 +464,7 @@ async def run_approach(approach):
return list(responses), sum(tokens)

def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str,
request_config: dict = None) -> Tuple[Union[str, List[str]], int]:
request_config: dict = None, request_id: str = None) -> Tuple[Union[str, List[str]], int]:
"""
Execute the pipeline n times and return n responses.

@@ -466,7 +485,7 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init

for _ in range(n):
if operation == 'SINGLE':
response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id)
elif operation == 'AND':
response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config)
elif operation == 'OR':
@@ -678,6 +697,21 @@ def proxy():
operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')

# Start conversation logging if enabled
request_id = None
if conversation_logger and conversation_logger.enabled:
request_id = conversation_logger.start_conversation(
client_request={
'messages': messages,
'model': data.get('model', server_config['model']),
'stream': stream,
'n': n,
**{k: v for k, v in data.items() if k not in {'messages', 'model', 'stream', 'n'}}
},
approach=approaches[0] if len(approaches) == 1 else f"{operation}({','.join(approaches)})",
model=model
)

if bearer_token != "" and bearer_token.startswith("sk-"):
api_key = bearer_token
if base_url != "":
@@ -718,10 +752,15 @@ def proxy():

if operation == 'SINGLE' and approaches[0] == 'none':
# Pass through the request including the n parameter
result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id)

logger.debug(f'Direct proxy response: {result}')

# Log the final response and finalize conversation logging
if conversation_logger and request_id:
conversation_logger.log_final_response(request_id, result)
conversation_logger.finalize_conversation(request_id)

if stream:
return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream')
else :
@@ -732,9 +771,14 @@ def proxy():
raise ValueError("'none' approach cannot be combined with other approaches")

# Handle non-none approaches with n attempts
response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config)
response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config, request_id)

except Exception as e:
# Log error to conversation logger if enabled
if conversation_logger and request_id:
conversation_logger.log_error(request_id, str(e))
conversation_logger.finalize_conversation(request_id)

logger.error(f"Error processing request: {str(e)}")
return jsonify({"error": str(e)}), 500

@@ -793,6 +837,11 @@ def proxy():
'finish_reason': 'stop'
})

# Log the final response and finalize conversation logging
if conversation_logger and request_id:
conversation_logger.log_final_response(request_id, response_data)
conversation_logger.finalize_conversation(request_id)

logger.debug(f'API response: {response_data}')
return jsonify(response_data), 200

@@ -848,6 +897,8 @@ def parse_args():
("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())),
("--launch-gui", "OPTILLM_LAUNCH_GUI", bool, False, "Launch a Gradio chat interface"),
("--plugins-dir", "OPTILLM_PLUGINS_DIR", str, "", "Path to the plugins directory"),
("--log-conversations", "OPTILLM_LOG_CONVERSATIONS", bool, False, "Enable conversation logging with full metadata"),
("--conversation-log-dir", "OPTILLM_CONVERSATION_LOG_DIR", str, str(Path.home() / ".optillm" / "conversations"), "Directory to save conversation logs"),
]

for arg, env, type_, default, help_text, *extra in args_env:
@@ -920,6 +971,7 @@ def main():
global server_config
global cepo_config
global request_batcher
global conversation_logger
# Call this function at the start of main()
args = parse_args()
# Update server_config with all argument values
@@ -1075,6 +1127,15 @@ def process_batch_requests(batch_requests):
if logging_level in logging_levels.keys():
logger.setLevel(logging_levels[logging_level])

# Initialize conversation logger if enabled
global conversation_logger
conversation_logger = ConversationLogger(
log_dir=Path(server_config['conversation_log_dir']),
enabled=server_config['log_conversations']
)
if server_config['log_conversations']:
logger.info(f"Conversation logging enabled. Logs will be saved to: {server_config['conversation_log_dir']}")

# set and log the cepo configs
cepo_config = init_cepo_config(server_config)
if args.approach == 'cepo':
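Note: the diff imports `ConversationLogger` from `optillm.conversation_logger`, but that module is not shown in this changeset. Based solely on the call sites above (`start_conversation`, `log_provider_call`, `log_error`, `log_final_response`, `finalize_conversation`), a hypothetical minimal interface could look like the sketch below; the actual implementation in the PR may differ in its internals (request-ID scheme, on-disk format, and so on are assumptions here).

```python
import json
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict


class ConversationLogger:
    """Hypothetical sketch of the interface implied by the call sites in optillm.py."""

    def __init__(self, log_dir: Path, enabled: bool = False):
        self.log_dir = log_dir
        self.enabled = enabled
        self._conversations: Dict[str, Dict[str, Any]] = {}
        if enabled:
            self.log_dir.mkdir(parents=True, exist_ok=True)

    def start_conversation(self, client_request: Dict[str, Any], approach: str, model: str) -> str:
        # Assign an opaque request ID that downstream approaches thread through.
        request_id = str(uuid.uuid4())
        self._conversations[request_id] = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "client_request": client_request,
            "approach": approach,
            "model": model,
            "provider_calls": [],
            "errors": [],
        }
        return request_id

    def log_provider_call(self, request_id: str, request: Dict[str, Any], response: Any) -> None:
        entry = self._conversations.get(request_id)
        if entry is not None:
            entry["provider_calls"].append({"request": request, "response": response})

    def log_error(self, request_id: str, message: str) -> None:
        entry = self._conversations.get(request_id)
        if entry is not None:
            entry["errors"].append(message)

    def log_final_response(self, request_id: str, response: Any) -> None:
        entry = self._conversations.get(request_id)
        if entry is not None:
            entry["final_response"] = response

    def finalize_conversation(self, request_id: str) -> None:
        # Persist one JSON file per conversation (assumed format).
        entry = self._conversations.pop(request_id, None)
        if entry is not None and self.enabled:
            path = self.log_dir / f"{request_id}.json"
            path.write_text(json.dumps(entry, indent=2, default=str))
```

Per the `parse_args` additions, logging would presumably be enabled with `--log-conversations` (or `OPTILLM_LOG_CONVERSATIONS`), with logs written under `~/.optillm/conversations` by default and the directory overridable via `--conversation-log-dir` (or `OPTILLM_CONVERSATION_LOG_DIR`).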
64 changes: 43 additions & 21 deletions optillm/bon.py
@@ -1,8 +1,9 @@
import logging
import optillm

logger = logging.getLogger(__name__)

def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
bon_completion_tokens = 0

messages = [{"role": "system", "content": system_prompt},
@@ -12,13 +13,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st

try:
# Try to generate n completions in a single API call using n parameter
response = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=4096,
n=n,
temperature=1
)
provider_request = {
"model": model,
"messages": messages,
"max_tokens": 4096,
"n": n,
"temperature": 1
}
response = client.chat.completions.create(**provider_request)

# Log provider call
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)

completions = [choice.message.content for choice in response.choices]
logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
bon_completion_tokens += response.usage.completion_tokens
@@ -30,12 +38,19 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
# Fallback: Generate completions one by one in a loop
for i in range(n):
try:
response = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=4096,
temperature=1
)
provider_request = {
"model": model,
"messages": messages,
"max_tokens": 4096,
"temperature": 1
}
response = client.chat.completions.create(**provider_request)

# Log provider call
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)

completions.append(response.choices[0].message.content)
bon_completion_tokens += response.usage.completion_tokens
logger.debug(f"Generated completion {i+1}/{n}")
@@ -59,13 +74,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
rating_messages.append({"role": "assistant", "content": completion})
rating_messages.append({"role": "user", "content": "Rate the above response:"})

rating_response = client.chat.completions.create(
model=model,
messages=rating_messages,
max_tokens=256,
n=1,
temperature=0.1
)
provider_request = {
"model": model,
"messages": rating_messages,
"max_tokens": 256,
"n": 1,
"temperature": 0.1
}
rating_response = client.chat.completions.create(**provider_request)

# Log provider call
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)

bon_completion_tokens += rating_response.usage.completion_tokens
try:
rating = float(rating_response.choices[0].message.content.strip())
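The `hasattr(optillm, 'conversation_logger')` guard plus the `model_dump` conversion is repeated after every provider call in bon.py. A small helper along these lines (hypothetical, not part of the PR) could centralize that pattern while returning the response as a plain dict:

```python
import optillm


def log_provider_call(request_id, provider_request, response):
    """Log a provider request/response pair if conversation logging is active.

    Hypothetical helper mirroring the guard used in bon.py; returns the
    response as a plain dict for convenience.
    """
    # Normalize SDK response objects to dicts, matching the pattern in the diff.
    response_dict = response.model_dump() if hasattr(response, "model_dump") else response
    conversation_logger = getattr(optillm, "conversation_logger", None)
    if conversation_logger is not None and request_id:
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
    return response_dict
```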