2 changes: 1 addition & 1 deletion optillm/__init__.py
@@ -1,5 +1,5 @@
# Version information
__version__ = "0.3.3"
__version__ = "0.3.4"

# Import from server module
from .server import (
58 changes: 44 additions & 14 deletions optillm/bon.py
@@ -22,16 +22,24 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
"temperature": 1
}
response = client.chat.completions.create(**provider_request)

# Log provider call
if request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
conversation_logger.log_provider_call(request_id, provider_request, response_dict)

completions = [choice.message.content for choice in response.choices]

# Check for valid response with None-checking
if response is None or not response.choices:
raise Exception("Response is None or has no choices")

completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
bon_completion_tokens += response.usage.completion_tokens


# Check if any valid completions were generated
if not completions:
raise Exception("No valid completions generated (all were None)")

except Exception as e:
logger.warning(f"n parameter not supported by provider: {str(e)}")
logger.info(f"Falling back to generating {n} completions one by one")
@@ -46,12 +54,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
"temperature": 1
}
response = client.chat.completions.create(**provider_request)

# Log provider call
if request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
conversation_logger.log_provider_call(request_id, provider_request, response_dict)


# Check for valid response with None-checking
if (response is None or
not response.choices or
response.choices[0].message.content is None or
response.choices[0].finish_reason == "length"):
logger.warning(f"Completion {i+1}/{n} truncated or empty, skipping")
continue

completions.append(response.choices[0].message.content)
bon_completion_tokens += response.usage.completion_tokens
logger.debug(f"Generated completion {i+1}/{n}")
@@ -65,11 +81,16 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
return "Error: Could not generate any completions", 0

logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {bon_completion_tokens}")


# Double-check we have completions before rating
if not completions:
logger.error("No completions available for rating")
return "Error: Could not generate any completions", bon_completion_tokens

# Rate the completions
rating_messages = messages.copy()
rating_messages.append({"role": "system", "content": "Rate the following responses on a scale from 0 to 10, where 0 is poor and 10 is excellent. Consider factors such as relevance, coherence, and helpfulness. Respond with only a number."})

ratings = []
for completion in completions:
rating_messages.append({"role": "assistant", "content": completion})
@@ -83,18 +104,27 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
"temperature": 0.1
}
rating_response = client.chat.completions.create(**provider_request)

# Log provider call
if request_id:
response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
conversation_logger.log_provider_call(request_id, provider_request, response_dict)

bon_completion_tokens += rating_response.usage.completion_tokens
try:
rating = float(rating_response.choices[0].message.content.strip())
ratings.append(rating)
except ValueError:

# Check for valid response with None-checking
if (rating_response is None or
not rating_response.choices or
rating_response.choices[0].message.content is None or
rating_response.choices[0].finish_reason == "length"):
logger.warning("Rating response truncated or empty, using default rating of 0")
ratings.append(0)
else:
try:
rating = float(rating_response.choices[0].message.content.strip())
ratings.append(rating)
except ValueError:
ratings.append(0)

rating_messages = rating_messages[:-2]

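The same None/truncation guard is repeated before every read of `response.choices[0].message.content` in bon.py. As a minimal illustrative sketch (the helper name `is_valid_response` is an assumption, not code from this PR), the check could be factored out like this:

```python
# Illustrative only: the validity check this PR repeats at each call site,
# extracted into a hypothetical helper. Not part of the actual diff.
def is_valid_response(response) -> bool:
    """True if the provider returned at least one complete, non-empty choice."""
    return (
        response is not None
        and bool(response.choices)
        and response.choices[0].message.content is not None
        and response.choices[0].finish_reason != "length"
    )
```

Call sites would then read `if not is_valid_response(response): continue`, mirroring the per-completion fallback loop above.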
39 changes: 31 additions & 8 deletions optillm/mcts.py
@@ -122,13 +122,18 @@ def generate_actions(self, state: DialogueState) -> List[str]:
"temperature": 1
}
response = self.client.chat.completions.create(**provider_request)

# Log provider call
if self.request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)

completions = [choice.message.content.strip() for choice in response.choices]

# Check for valid response with None-checking
if response is None or not response.choices:
logger.error("Failed to get valid completions from the model")
return []

completions = [choice.message.content.strip() for choice in response.choices if choice.message.content is not None]
self.completion_tokens += response.usage.completion_tokens
logger.info(f"Received {len(completions)} completions from the model")
return completions
@@ -151,13 +156,22 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
"temperature": 1
}
response = self.client.chat.completions.create(**provider_request)

# Log provider call
if self.request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)

next_query = response.choices[0].message.content

# Check for valid response with None-checking
if (response is None or
not response.choices or
response.choices[0].message.content is None or
response.choices[0].finish_reason == "length"):
logger.warning("Next query response truncated or empty, using default")
next_query = "Please continue."
else:
next_query = response.choices[0].message.content

self.completion_tokens += response.usage.completion_tokens
logger.info(f"Generated next user query: {next_query}")
return DialogueState(state.system_prompt, new_history, next_query)
@@ -181,13 +195,22 @@ def evaluate_state(self, state: DialogueState) -> float:
"temperature": 0.1
}
response = self.client.chat.completions.create(**provider_request)

# Log provider call
if self.request_id:
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)

self.completion_tokens += response.usage.completion_tokens

# Check for valid response with None-checking
if (response is None or
not response.choices or
response.choices[0].message.content is None or
response.choices[0].finish_reason == "length"):
logger.warning("Evaluation response truncated or empty. Using default value 0.5")
return 0.5

try:
score = float(response.choices[0].message.content.strip())
score = max(0, min(score, 1)) # Ensure the score is between 0 and 1
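In mcts.py, `evaluate_state` now parses the model's numeric score, clamps it to [0, 1], and falls back to 0.5 when the reply is unusable. A standalone sketch of that parsing behavior (the function name, signature, and the ValueError fallback value are assumptions for illustration, since the except branch is not shown in the visible diff):

```python
from typing import Optional

# Sketch of the score-parsing rules added to evaluate_state; the helper itself
# is hypothetical and only illustrates the fallback behavior.
def parse_score(content: Optional[str], finish_reason: Optional[str]) -> float:
    """Return a score in [0, 1]; use 0.5 if the reply is empty, truncated,
    or not a valid number."""
    if content is None or finish_reason == "length":
        return 0.5
    try:
        # Clamp to [0, 1], matching the score = max(0, min(score, 1)) line above.
        return max(0.0, min(float(content.strip()), 1.0))
    except ValueError:
        return 0.5
```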
77 changes: 59 additions & 18 deletions optillm/moa.py
@@ -25,18 +25,26 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
}

response = client.chat.completions.create(**provider_request)

# Convert response to dict for logging
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response

# Log provider call if conversation logging is enabled
if request_id:
conversation_logger.log_provider_call(request_id, provider_request, response_dict)

completions = [choice.message.content for choice in response.choices]

# Check for valid response with None-checking
if response is None or not response.choices:
raise Exception("Response is None or has no choices")

completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
moa_completion_tokens += response.usage.completion_tokens
logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")


# Check if any valid completions were generated
if not completions:
raise Exception("No valid completions generated (all were None)")

except Exception as e:
logger.warning(f"n parameter not supported by provider: {str(e)}")
logger.info("Falling back to generating 3 completions one by one")
@@ -56,14 +64,22 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
}

response = client.chat.completions.create(**provider_request)

# Convert response to dict for logging
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response

# Log provider call if conversation logging is enabled
if request_id:
conversation_logger.log_provider_call(request_id, provider_request, response_dict)


# Check for valid response with None-checking
if (response is None or
not response.choices or
response.choices[0].message.content is None or
response.choices[0].finish_reason == "length"):
logger.warning(f"Completion {i+1}/3 truncated or empty, skipping")
continue

completions.append(response.choices[0].message.content)
moa_completion_tokens += response.usage.completion_tokens
logger.debug(f"Generated completion {i+1}/3")
@@ -77,7 +93,12 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
return "Error: Could not generate any completions", 0

logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {moa_completion_tokens}")


# Double-check we have at least one completion
if not completions:
logger.error("No completions available for processing")
return "Error: Could not generate any completions", moa_completion_tokens

# Handle case where fewer than 3 completions were generated
if len(completions) < 3:
original_count = len(completions)
@@ -118,15 +139,24 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
}

critique_response = client.chat.completions.create(**provider_request)

# Convert response to dict for logging
response_dict = critique_response.model_dump() if hasattr(critique_response, 'model_dump') else critique_response

# Log provider call if conversation logging is enabled
if request_id:
conversation_logger.log_provider_call(request_id, provider_request, response_dict)

critiques = critique_response.choices[0].message.content

# Check for valid response with None-checking
if (critique_response is None or
not critique_response.choices or
critique_response.choices[0].message.content is None or
critique_response.choices[0].finish_reason == "length"):
logger.warning("Critique response truncated or empty, using generic critique")
critiques = "All candidates show reasonable approaches to the problem."
else:
critiques = critique_response.choices[0].message.content

moa_completion_tokens += critique_response.usage.completion_tokens
logger.info(f"Generated critiques. Tokens used: {critique_response.usage.completion_tokens}")

@@ -165,16 +195,27 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
}

final_response = client.chat.completions.create(**provider_request)

# Convert response to dict for logging
response_dict = final_response.model_dump() if hasattr(final_response, 'model_dump') else final_response

# Log provider call if conversation logging is enabled
if request_id:
conversation_logger.log_provider_call(request_id, provider_request, response_dict)

moa_completion_tokens += final_response.usage.completion_tokens
logger.info(f"Generated final response. Tokens used: {final_response.usage.completion_tokens}")


# Check for valid response with None-checking
if (final_response is None or
not final_response.choices or
final_response.choices[0].message.content is None or
final_response.choices[0].finish_reason == "length"):
logger.error("Final response truncated or empty. Consider increasing max_tokens.")
# Return best completion if final response failed
result = completions[0] if completions else "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
else:
result = final_response.choices[0].message.content

logger.info(f"Total completion tokens used: {moa_completion_tokens}")
return final_response.choices[0].message.content, moa_completion_tokens
return result, moa_completion_tokens
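Both bon.py and moa.py follow the same overall shape: request all candidates in one call via the `n` parameter, and fall back to one request per candidate when the provider rejects `n` or returns nothing usable. A condensed sketch of that pattern (the function name, defaults, and the simplified error handling are assumptions, not code from this PR; the API calls follow the OpenAI `chat.completions` interface used in the diff):

```python
# Hypothetical sketch of the n-parameter fallback pattern shared by bon.py
# and moa.py. Logging and token accounting are omitted for brevity.
def generate_candidates(client, model, messages, n=3, max_tokens=4096):
    try:
        response = client.chat.completions.create(
            model=model, messages=messages, n=n,
            max_tokens=max_tokens, temperature=1,
        )
        texts = [c.message.content for c in response.choices
                 if c.message.content is not None]
        if not texts:
            # All choices were None: treat it like an unsupported n parameter.
            raise Exception("No valid completions generated (all were None)")
        return texts
    except Exception:
        # Provider rejected n > 1 (or returned nothing usable):
        # fall back to one request per candidate.
        texts = []
        for _ in range(n):
            r = client.chat.completions.create(
                model=model, messages=messages,
                max_tokens=max_tokens, temperature=1,
            )
            if r.choices and r.choices[0].message.content is not None:
                texts.append(r.choices[0].message.content)
        return texts
```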